mz_expr/
relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39    ColumnName, ColumnType, Datum, Diff, GlobalId, IntoRowIterator, RelationType, Row, RowIterator,
40    ScalarType,
41};
42use proptest::prelude::{Arbitrary, BoxedStrategy, any};
43use proptest::strategy::{Strategy, Union};
44use proptest_derive::Arbitrary;
45use serde::{Deserialize, Serialize};
46
47use crate::Id::Local;
48use crate::explain::{HumanizedExpr, HumanizerMode};
49use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
50use crate::row::{RowCollection, SortedRowCollectionIter};
51use crate::visit::{Visit, VisitChildren};
52use crate::{
53    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
54    func as scalar_func,
55};
56
57pub mod canonicalize;
58pub mod func;
59pub mod join_input_mapper;
60
61include!(concat!(env!("OUT_DIR"), "/mz_expr.relation.rs"));
62
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
74
75/// A trait for types that describe how to build a collection.
76pub trait CollectionPlan {
77    /// Collects the set of global identifiers from dataflows referenced in Get.
78    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
79
80    /// Returns the set of global identifiers from dataflows referenced in Get.
81    ///
82    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
83    fn depends_on(&self) -> BTreeSet<GlobalId> {
84        let mut out = BTreeSet::new();
85        self.depends_on_into(&mut out);
86        out
87    }
88}
89
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities, or an
        /// evaluation error in place of the rows.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: RelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: RelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`: evaluated once in the context of the
        /// final iterates of the bindings.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Scalar expressions used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each item of `keys` represents a different index that should be
    /// produced.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
318
319impl PartialEq for MirRelationExpr {
320    fn eq(&self, other: &Self) -> bool {
321        // Capture the result and test it wrt `Ord` implementation in test environments.
322        let result = structured_diff::MreDiff::new(self, other).next().is_none();
323        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
324        result
325    }
326}
327impl Eq for MirRelationExpr {}
328
329impl Arbitrary for MirRelationExpr {
330    type Parameters = ();
331
332    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
333        const VEC_LEN: usize = 2;
334
335        // A strategy for generating the leaf cases.
336        let leaf = Union::new([
337            (
338                any::<Result<Vec<(Row, Diff)>, EvalError>>(),
339                any::<RelationType>(),
340            )
341                .prop_map(|(rows, typ)| MirRelationExpr::Constant { rows, typ })
342                .boxed(),
343            (any::<Id>(), any::<RelationType>(), any::<AccessStrategy>())
344                .prop_map(|(id, typ, access_strategy)| MirRelationExpr::Get {
345                    id,
346                    typ,
347                    access_strategy,
348                })
349                .boxed(),
350        ])
351        .boxed();
352
353        leaf.prop_recursive(2, 2, 2, |inner| {
354            Union::new([
355                // Let
356                (any::<LocalId>(), inner.clone(), inner.clone())
357                    .prop_map(|(id, value, body)| MirRelationExpr::Let {
358                        id,
359                        value: Box::new(value),
360                        body: Box::new(body),
361                    })
362                    .boxed(),
363                // LetRec
364                (
365                    any::<Vec<LocalId>>(),
366                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
367                    any::<Vec<Option<LetRecLimit>>>(),
368                    inner.clone(),
369                )
370                    .prop_map(|(ids, values, limits, body)| MirRelationExpr::LetRec {
371                        ids,
372                        values,
373                        limits,
374                        body: Box::new(body),
375                    })
376                    .boxed(),
377                // Project
378                (inner.clone(), any::<Vec<usize>>())
379                    .prop_map(|(input, outputs)| MirRelationExpr::Project {
380                        input: Box::new(input),
381                        outputs,
382                    })
383                    .boxed(),
384                // Map
385                (inner.clone(), any::<Vec<MirScalarExpr>>())
386                    .prop_map(|(input, scalars)| MirRelationExpr::Map {
387                        input: Box::new(input),
388                        scalars,
389                    })
390                    .boxed(),
391                // FlatMap
392                (
393                    inner.clone(),
394                    any::<TableFunc>(),
395                    any::<Vec<MirScalarExpr>>(),
396                )
397                    .prop_map(|(input, func, exprs)| MirRelationExpr::FlatMap {
398                        input: Box::new(input),
399                        func,
400                        exprs,
401                    })
402                    .boxed(),
403                // Filter
404                (inner.clone(), any::<Vec<MirScalarExpr>>())
405                    .prop_map(|(input, predicates)| MirRelationExpr::Filter {
406                        input: Box::new(input),
407                        predicates,
408                    })
409                    .boxed(),
410                // Join
411                (
412                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
413                    any::<Vec<Vec<MirScalarExpr>>>(),
414                    any::<JoinImplementation>(),
415                )
416                    .prop_map(
417                        |(inputs, equivalences, implementation)| MirRelationExpr::Join {
418                            inputs,
419                            equivalences,
420                            implementation,
421                        },
422                    )
423                    .boxed(),
424                // Reduce
425                (
426                    inner.clone(),
427                    any::<Vec<MirScalarExpr>>(),
428                    any::<Vec<AggregateExpr>>(),
429                    any::<bool>(),
430                    any::<Option<u64>>(),
431                )
432                    .prop_map(
433                        |(input, group_key, aggregates, monotonic, expected_group_size)| {
434                            MirRelationExpr::Reduce {
435                                input: Box::new(input),
436                                group_key,
437                                aggregates,
438                                monotonic,
439                                expected_group_size,
440                            }
441                        },
442                    )
443                    .boxed(),
444                // TopK
445                (
446                    inner.clone(),
447                    any::<Vec<usize>>(),
448                    any::<Vec<ColumnOrder>>(),
449                    any::<Option<MirScalarExpr>>(),
450                    any::<usize>(),
451                    any::<bool>(),
452                    any::<Option<u64>>(),
453                )
454                    .prop_map(
455                        |(
456                            input,
457                            group_key,
458                            order_key,
459                            limit,
460                            offset,
461                            monotonic,
462                            expected_group_size,
463                        )| MirRelationExpr::TopK {
464                            input: Box::new(input),
465                            group_key,
466                            order_key,
467                            limit,
468                            offset,
469                            monotonic,
470                            expected_group_size,
471                        },
472                    )
473                    .boxed(),
474                // Negate
475                inner
476                    .clone()
477                    .prop_map(|input| MirRelationExpr::Negate {
478                        input: Box::new(input),
479                    })
480                    .boxed(),
481                // Threshold
482                inner
483                    .clone()
484                    .prop_map(|input| MirRelationExpr::Threshold {
485                        input: Box::new(input),
486                    })
487                    .boxed(),
488                // Union
489                (
490                    inner.clone(),
491                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
492                )
493                    .prop_map(|(base, inputs)| MirRelationExpr::Union {
494                        base: Box::new(base),
495                        inputs,
496                    })
497                    .boxed(),
498                // ArrangeBy
499                (inner.clone(), any::<Vec<Vec<MirScalarExpr>>>())
500                    .prop_map(|(input, keys)| MirRelationExpr::ArrangeBy {
501                        input: Box::new(input),
502                        keys,
503                    })
504                    .boxed(),
505            ])
506        })
507        .boxed()
508    }
509
510    type Strategy = BoxedStrategy<Self>;
511}
512
513impl MirRelationExpr {
514    /// Reports the schema of the relation.
515    ///
516    /// This method determines the type through recursive traversal of the
517    /// relation expression, drawing from the types of base collections.
518    /// As such, this is not an especially cheap method, and should be used
519    /// judiciously.
520    ///
521    /// The relation type is computed incrementally with a recursive post-order
522    /// traversal, that accumulates the input types for the relations yet to be
523    /// visited in `type_stack`.
524    pub fn typ(&self) -> RelationType {
525        let mut type_stack = Vec::new();
526        #[allow(deprecated)]
527        self.visit_pre_post_nolimit(
528            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
529                match &e {
530                    MirRelationExpr::Let { body, .. } => {
531                        // Do not traverse the value sub-graph, since it's not relevant for
532                        // determining the relation type of Let operators.
533                        Some(vec![&*body])
534                    }
535                    MirRelationExpr::LetRec { body, .. } => {
536                        // Do not traverse the value sub-graph, since it's not relevant for
537                        // determining the relation type of Let operators.
538                        Some(vec![&*body])
539                    }
540                    _ => None,
541                }
542            },
543            &mut |e: &MirRelationExpr| {
544                match e {
545                    MirRelationExpr::Let { .. } => {
546                        let body_typ = type_stack.pop().unwrap();
547                        // Insert a dummy relation type for the value, since `typ_with_input_types`
548                        // won't look at it, but expects the relation type of the body to be second.
549                        type_stack.push(RelationType::empty());
550                        type_stack.push(body_typ);
551                    }
552                    MirRelationExpr::LetRec { values, .. } => {
553                        let body_typ = type_stack.pop().unwrap();
554                        // Insert dummy relation types for the values, since `typ_with_input_types`
555                        // won't look at them, but expects the relation type of the body to be last.
556                        type_stack
557                            .extend(std::iter::repeat(RelationType::empty()).take(values.len()));
558                        type_stack.push(body_typ);
559                    }
560                    _ => {}
561                }
562                let num_inputs = e.num_inputs();
563                let relation_type =
564                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
565                type_stack.truncate(type_stack.len() - num_inputs);
566                type_stack.push(relation_type);
567            },
568        );
569        assert_eq!(type_stack.len(), 1);
570        type_stack.pop().unwrap()
571    }
572
573    /// Reports the schema of the relation given the schema of the input relations.
574    ///
575    /// `input_types` is required to contain the schemas for the input relations of
576    /// the current relation in the same order as they are visited by `try_visit_children`
577    /// method, even though not all may be used for computing the schema of the
578    /// current relation. For example, `Let` expects two input types, one for the
579    /// value relation and one for the body, in that order, but only the one for the
580    /// body is used to determine the type of the `Let` relation.
581    ///
582    /// It is meant to be used during post-order traversals to compute relation
583    /// schemas incrementally.
584    pub fn typ_with_input_types(&self, input_types: &[RelationType]) -> RelationType {
585        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
586        let unique_keys = self.keys_with_input_keys(
587            input_types.iter().map(|i| i.arity()),
588            input_types.iter().map(|i| &i.keys),
589        );
590        RelationType::new(column_types).with_keys(unique_keys)
591    }
592
593    /// Reports the column types of the relation given the column types of the
594    /// input relations.
595    ///
596    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
597    /// variant is returned.
598    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ColumnType>
599    where
600        I: Iterator<Item = &'a Vec<ColumnType>>,
601    {
602        match self.try_col_with_input_cols(input_types) {
603            Ok(col_types) => col_types,
604            Err(err) => panic!("{err}"),
605        }
606    }
607
608    /// Reports the column types of the relation given the column types of the input relations.
609    ///
610    /// `input_types` is required to contain the column types for the input relations of
611    /// the current relation in the same order as they are visited by `try_visit_children`
612    /// method, even though not all may be used for computing the schema of the
613    /// current relation. For example, `Let` expects two input types, one for the
614    /// value relation and one for the body, in that order, but only the one for the
615    /// body is used to determine the type of the `Let` relation.
616    ///
617    /// It is meant to be used during post-order traversals to compute column types
618    /// incrementally.
619    pub fn try_col_with_input_cols<'a, I>(
620        &self,
621        mut input_types: I,
622    ) -> Result<Vec<ColumnType>, String>
623    where
624        I: Iterator<Item = &'a Vec<ColumnType>>,
625    {
626        use MirRelationExpr::*;
627
628        let col_types = match self {
629            Constant { rows, typ } => {
630                let mut col_types = typ.column_types.clone();
631                let mut seen_null = vec![false; typ.arity()];
632                if let Ok(rows) = rows {
633                    for (row, _diff) in rows {
634                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
635                            if datum.is_null() {
636                                seen_null[i] = true;
637                            }
638                        }
639                    }
640                }
641                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
642                    if !seen_null {
643                        col_types[i].nullable = false;
644                    } else {
645                        assert!(col_types[i].nullable);
646                    }
647                }
648                col_types
649            }
650            Get { typ, .. } => typ.column_types.clone(),
651            Project { outputs, .. } => {
652                let input = input_types.next().unwrap();
653                outputs.iter().map(|&i| input[i].clone()).collect()
654            }
655            Map { scalars, .. } => {
656                let mut result = input_types.next().unwrap().clone();
657                for scalar in scalars.iter() {
658                    result.push(scalar.typ(&result))
659                }
660                result
661            }
662            FlatMap { func, .. } => {
663                let mut result = input_types.next().unwrap().clone();
664                result.extend(func.output_type().column_types);
665                result
666            }
667            Filter { predicates, .. } => {
668                let mut result = input_types.next().unwrap().clone();
669
670                // Set as nonnull any columns where null values would cause
671                // any predicate to evaluate to null.
672                for column in non_nullable_columns(predicates) {
673                    result[column].nullable = false;
674                }
675                result
676            }
677            Join { equivalences, .. } => {
678                // Concatenate input column types
679                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
680                // In an equivalence class, if any column is non-null, then make all non-null
681                for equivalence in equivalences {
682                    let col_inds = equivalence
683                        .iter()
684                        .filter_map(|expr| match expr {
685                            MirScalarExpr::Column(col) => Some(*col),
686                            _ => None,
687                        })
688                        .collect_vec();
689                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
690                        for i in col_inds {
691                            types.get_mut(i).unwrap().nullable = false;
692                        }
693                    }
694                }
695                types
696            }
697            Reduce {
698                group_key,
699                aggregates,
700                ..
701            } => {
702                let input = input_types.next().unwrap();
703                group_key
704                    .iter()
705                    .map(|e| e.typ(input))
706                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
707                    .collect()
708            }
709            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
710                input_types.next().unwrap().clone()
711            }
712            Let { .. } => {
713                // skip over the input types for `value`.
714                input_types.nth(1).unwrap().clone()
715            }
716            LetRec { values, .. } => {
717                // skip over the input types for `values`.
718                input_types.nth(values.len()).unwrap().clone()
719            }
720            Union { .. } => {
721                let mut result = input_types.next().unwrap().clone();
722                for input_col_types in input_types {
723                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
724                        *base_col = base_col
725                            .union(col)
726                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
727                    }
728                }
729                result
730            }
731        };
732
733        Ok(col_types)
734    }
735
736    /// Reports the unique keys of the relation given the arities and the unique
737    /// keys of the input relations.
738    ///
739    /// `input_arities` and `input_keys` are required to contain the
740    /// corresponding info for the input relations of
741    /// the current relation in the same order as they are visited by `try_visit_children`
742    /// method, even though not all may be used for computing the schema of the
743    /// current relation. For example, `Let` expects two input types, one for the
744    /// value relation and one for the body, in that order, but only the one for the
745    /// body is used to determine the type of the `Let` relation.
746    ///
747    /// It is meant to be used during post-order traversals to compute unique keys
748    /// incrementally.
749    pub fn keys_with_input_keys<'a, I, J>(
750        &self,
751        mut input_arities: I,
752        mut input_keys: J,
753    ) -> Vec<Vec<usize>>
754    where
755        I: Iterator<Item = usize>,
756        J: Iterator<Item = &'a Vec<Vec<usize>>>,
757    {
758        use MirRelationExpr::*;
759
760        let mut keys = match self {
761            Constant {
762                rows: Ok(rows),
763                typ,
764            } => {
765                let n_cols = typ.arity();
766                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
767                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
768                for (row, diff) in rows {
769                    for (i, datum) in row.iter().enumerate() {
770                        if datum != Datum::Dummy {
771                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
772                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
773                                if is_dupe {
774                                    unique_values_per_col[i] = None;
775                                }
776                            }
777                        }
778                    }
779                }
780                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
781                    vec![vec![]]
782                } else {
783                    // XXX - Multi-column keys are not detected.
784                    typ.keys
785                        .iter()
786                        .cloned()
787                        .chain(
788                            unique_values_per_col
789                                .into_iter()
790                                .enumerate()
791                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
792                                .map(|(idx, _)| vec![idx]),
793                        )
794                        .collect()
795                }
796            }
797            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
798            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
799            Let { .. } => {
800                // skip over the unique keys for value
801                input_keys.nth(1).unwrap().clone()
802            }
803            LetRec { values, .. } => {
804                // skip over the unique keys for value
805                input_keys.nth(values.len()).unwrap().clone()
806            }
807            Project { outputs, .. } => {
808                let input = input_keys.next().unwrap();
809                input
810                    .iter()
811                    .filter_map(|key_set| {
812                        if key_set.iter().all(|k| outputs.contains(k)) {
813                            Some(
814                                key_set
815                                    .iter()
816                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
817                                    .collect(),
818                            )
819                        } else {
820                            None
821                        }
822                    })
823                    .collect()
824            }
825            Map { scalars, .. } => {
826                let mut remappings = Vec::new();
827                let arity = input_arities.next().unwrap();
828                for (column, scalar) in scalars.iter().enumerate() {
829                    // assess whether the scalar preserves uniqueness,
830                    // and could participate in a key!
831
832                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
833                        match expr {
834                            MirScalarExpr::CallUnary { func, expr } => {
835                                if func.preserves_uniqueness() {
836                                    uniqueness(expr)
837                                } else {
838                                    None
839                                }
840                            }
841                            MirScalarExpr::Column(c) => Some(*c),
842                            _ => None,
843                        }
844                    }
845
846                    if let Some(c) = uniqueness(scalar) {
847                        remappings.push((c, column + arity));
848                    }
849                }
850
851                let mut result = input_keys.next().unwrap().clone();
852                let mut new_keys = Vec::new();
853                // Any column in `remappings` could be replaced in a key
854                // by the corresponding c. This could lead to combinatorial
855                // explosion using our current representation, so we wont
856                // do that. Instead, we'll handle the case of one remapping.
857                if remappings.len() == 1 {
858                    let (old, new) = remappings.pop().unwrap();
859                    for key in &result {
860                        if key.contains(&old) {
861                            let mut new_key: Vec<usize> =
862                                key.iter().cloned().filter(|k| k != &old).collect();
863                            new_key.push(new);
864                            new_key.sort_unstable();
865                            new_keys.push(new_key);
866                        }
867                    }
868                    result.append(&mut new_keys);
869                }
870                result
871            }
872            FlatMap { .. } => {
873                // FlatMap can add duplicate rows, so input keys are no longer
874                // valid
875                vec![]
876            }
877            Negate { .. } => {
878                // Although negate may have distinct records for each key,
879                // the multiplicity is -1 rather than 1. This breaks many
880                // of the optimization uses of "keys".
881                vec![]
882            }
883            Filter { predicates, .. } => {
884                // A filter inherits the keys of its input unless the filters
885                // have reduced the input to a single row, in which case the
886                // keys of the input are `()`.
887                let mut input = input_keys.next().unwrap().clone();
888
889                if !input.is_empty() {
890                    // Track columns equated to literals, which we can prune.
891                    let mut cols_equal_to_literal = BTreeSet::new();
892
893                    // Perform union find on `col1 = col2` to establish
894                    // connected components of equated columns. Absent any
895                    // equalities, this will be `0 .. #c` (where #c is the
896                    // greatest column referenced by a predicate), but each
897                    // equality will orient the root of the greater to the root
898                    // of the lesser.
899                    let mut union_find = Vec::new();
900
901                    for expr in predicates.iter() {
902                        if let MirScalarExpr::CallBinary {
903                            func: crate::BinaryFunc::Eq,
904                            expr1,
905                            expr2,
906                        } = expr
907                        {
908                            if let MirScalarExpr::Column(c) = &**expr1 {
909                                if expr2.is_literal_ok() {
910                                    cols_equal_to_literal.insert(c);
911                                }
912                            }
913                            if let MirScalarExpr::Column(c) = &**expr2 {
914                                if expr1.is_literal_ok() {
915                                    cols_equal_to_literal.insert(c);
916                                }
917                            }
918                            // Perform union-find to equate columns.
919                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
920                                if c1 != c2 {
921                                    // Ensure union_find has entries up to
922                                    // max(c1, c2) by filling up missing
923                                    // positions with identity mappings.
924                                    while union_find.len() <= std::cmp::max(c1, c2) {
925                                        union_find.push(union_find.len());
926                                    }
927                                    let mut r1 = c1; // Find the representative column of [c1].
928                                    while r1 != union_find[r1] {
929                                        assert!(union_find[r1] < r1);
930                                        r1 = union_find[r1];
931                                    }
932                                    let mut r2 = c2; // Find the representative column of [c2].
933                                    while r2 != union_find[r2] {
934                                        assert!(union_find[r2] < r2);
935                                        r2 = union_find[r2];
936                                    }
937                                    // Union [c1] and [c2] by pointing the
938                                    // larger to the smaller representative (we
939                                    // update the remaining equivalence class
940                                    // members only once after this for-loop).
941                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
942                                }
943                            }
944                        }
945                    }
946
947                    // Complete union-find by pointing each element at its representative column.
948                    for i in 0..union_find.len() {
949                        // Iteration not required, as each prior already references the right column.
950                        union_find[i] = union_find[union_find[i]];
951                    }
952
953                    // Remove columns bound to literals, and remap columns equated to earlier columns.
954                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
955                    for key_set in &mut input {
956                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
957                        for col in key_set.iter_mut() {
958                            if let Some(equiv) = union_find.get(*col) {
959                                *col = *equiv;
960                            }
961                        }
962                        key_set.sort();
963                        key_set.dedup();
964                    }
965                    input.sort();
966                    input.dedup();
967
968                    // Expand out each key to each of its equivalent forms.
969                    // Each instance of `col` can be replaced by any equivalent column.
970                    // This has the potential to result in exponentially sized number of unique keys,
971                    // and in the future we should probably maintain unique keys modulo equivalence.
972
973                    // First, compute an inverse map from each representative
974                    // column `sub` to all other equivalent columns `col`.
975                    let mut subs = Vec::new();
976                    for (col, sub) in union_find.iter().enumerate() {
977                        if *sub != col {
978                            assert!(*sub < col);
979                            while subs.len() <= *sub {
980                                subs.push(Vec::new());
981                            }
982                            subs[*sub].push(col);
983                        }
984                    }
985                    // For each column, substitute for it in each occurrence.
986                    let mut to_add = Vec::new();
987                    for (col, subs) in subs.iter().enumerate() {
988                        if !subs.is_empty() {
989                            for key_set in input.iter() {
990                                if key_set.contains(&col) {
991                                    let mut to_extend = key_set.clone();
992                                    to_extend.retain(|c| c != &col);
993                                    for sub in subs {
994                                        to_extend.push(*sub);
995                                        to_add.push(to_extend.clone());
996                                        to_extend.pop();
997                                    }
998                                }
999                            }
1000                        }
1001                        // No deduplication, as we cannot introduce duplicates.
1002                        input.append(&mut to_add);
1003                    }
1004                    for key_set in input.iter_mut() {
1005                        key_set.sort();
1006                        key_set.dedup();
1007                    }
1008                }
1009                input
1010            }
1011            Join { equivalences, .. } => {
1012                // It is important the `new_from_input_arities` constructor is
1013                // used. Otherwise, Materialize may potentially end up in an
1014                // infinite loop.
1015                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
1016
1017                input_mapper.global_keys(input_keys, equivalences)
1018            }
1019            Reduce { group_key, .. } => {
1020                // The group key should form a key, but we might already have
1021                // keys that are subsets of the group key, and should retain
1022                // those instead, if so.
1023                let mut result = Vec::new();
1024                for key_set in input_keys.next().unwrap() {
1025                    if key_set
1026                        .iter()
1027                        .all(|k| group_key.contains(&MirScalarExpr::Column(*k)))
1028                    {
1029                        result.push(
1030                            key_set
1031                                .iter()
1032                                .map(|i| {
1033                                    group_key
1034                                        .iter()
1035                                        .position(|k| k == &MirScalarExpr::Column(*i))
1036                                        .unwrap()
1037                                })
1038                                .collect::<Vec<_>>(),
1039                        );
1040                    }
1041                }
1042                if result.is_empty() {
1043                    result.push((0..group_key.len()).collect());
1044                }
1045                result
1046            }
1047            TopK {
1048                group_key, limit, ..
1049            } => {
1050                // If `limit` is `Some(1)` then the group key will become
1051                // a unique key, as there will be only one record with that key.
1052                let mut result = input_keys.next().unwrap().clone();
1053                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
1054                    result.push(group_key.clone())
1055                }
1056                result
1057            }
1058            Union { base, inputs } => {
1059                // Generally, unions do not have any unique keys, because
1060                // each input might duplicate some. However, there is at
1061                // least one idiomatic structure that does preserve keys,
1062                // which results from SQL aggregations that must populate
1063                // absent records with default values. In that pattern,
1064                // the union of one GET with its negation, which has first
1065                // been subjected to a projection and map, we can remove
1066                // their influence on the key structure.
1067                //
1068                // If there are A, B, each with a unique `key` such that
1069                // we are looking at
1070                //
1071                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
1072                //
1073                // Then we can report `key` as a unique key.
1074                //
1075                // TODO: make unique key structure an optimization analysis
1076                // rather than part of the type information.
1077                // TODO: perhaps ensure that (above) A.proj(key) is a
1078                // subset of B, as otherwise there are negative records
1079                // and who knows what is true (not expected, but again
1080                // who knows what the query plan might look like).
1081
1082                let arity = input_arities.next().unwrap();
1083                let (base_projection, base_with_project_stripped) =
1084                    if let MirRelationExpr::Project { input, outputs } = &**base {
1085                        (outputs.clone(), &**input)
1086                    } else {
1087                        // A input without a project is equivalent to an input
1088                        // with the project being all columns in the input in order.
1089                        ((0..arity).collect::<Vec<_>>(), &**base)
1090                    };
1091                let mut result = Vec::new();
1092                if let MirRelationExpr::Get {
1093                    id: first_id,
1094                    typ: _,
1095                    ..
1096                } = base_with_project_stripped
1097                {
1098                    if inputs.len() == 1 {
1099                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
1100                            if let MirRelationExpr::Union { base, inputs } = &**input {
1101                                if inputs.len() == 1 {
1102                                    if let Some((input, outputs)) = base.is_negated_project() {
1103                                        if let MirRelationExpr::Get {
1104                                            id: second_id,
1105                                            typ: _,
1106                                            ..
1107                                        } = input
1108                                        {
1109                                            if first_id == second_id {
1110                                                result.extend(
1111                                                    input_keys
1112                                                        .next()
1113                                                        .unwrap()
1114                                                        .into_iter()
1115                                                        .filter(|key| {
1116                                                            key.iter().all(|c| {
1117                                                                outputs.get(*c) == Some(c)
1118                                                                    && base_projection.get(*c)
1119                                                                        == Some(c)
1120                                                            })
1121                                                        })
1122                                                        .cloned(),
1123                                                );
1124                                            }
1125                                        }
1126                                    }
1127                                }
1128                            }
1129                        }
1130                    }
1131                }
1132                // Important: do not inherit keys of either input, as not unique.
1133                result
1134            }
1135        };
1136        keys.sort();
1137        keys.dedup();
1138        keys
1139    }
1140
1141    /// The number of columns in the relation.
1142    ///
1143    /// This number is determined from the type, which is determined recursively
1144    /// at non-trivial cost.
1145    ///
1146    /// The arity is computed incrementally with a recursive post-order
1147    /// traversal, that accumulates the arities for the relations yet to be
1148    /// visited in `arity_stack`.
1149    pub fn arity(&self) -> usize {
1150        let mut arity_stack = Vec::new();
1151        #[allow(deprecated)]
1152        self.visit_pre_post_nolimit(
1153            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
1154                match &e {
1155                    MirRelationExpr::Let { body, .. } => {
1156                        // Do not traverse the value sub-graph, since it's not relevant for
1157                        // determining the arity of Let operators.
1158                        Some(vec![&*body])
1159                    }
1160                    MirRelationExpr::LetRec { body, .. } => {
1161                        // Do not traverse the value sub-graph, since it's not relevant for
1162                        // determining the arity of Let operators.
1163                        Some(vec![&*body])
1164                    }
1165                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
1166                        // No further traversal is required; these operators know their arity.
1167                        Some(Vec::new())
1168                    }
1169                    _ => None,
1170                }
1171            },
1172            &mut |e: &MirRelationExpr| {
1173                match &e {
1174                    MirRelationExpr::Let { .. } => {
1175                        let body_arity = arity_stack.pop().unwrap();
1176                        arity_stack.push(0);
1177                        arity_stack.push(body_arity);
1178                    }
1179                    MirRelationExpr::LetRec { values, .. } => {
1180                        let body_arity = arity_stack.pop().unwrap();
1181                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
1182                        arity_stack.push(body_arity);
1183                    }
1184                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
1185                        arity_stack.push(0);
1186                    }
1187                    _ => {}
1188                }
1189                let num_inputs = e.num_inputs();
1190                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
1191                let arity = e.arity_with_input_arities(input_arities);
1192                arity_stack.push(arity);
1193            },
1194        );
1195        assert_eq!(arity_stack.len(), 1);
1196        arity_stack.pop().unwrap()
1197    }
1198
1199    /// Reports the arity of the relation given the schema of the input relations.
1200    ///
1201    /// `input_arities` is required to contain the arities for the input relations of
1202    /// the current relation in the same order as they are visited by `try_visit_children`
1203    /// method, even though not all may be used for computing the schema of the
1204    /// current relation. For example, `Let` expects two input types, one for the
1205    /// value relation and one for the body, in that order, but only the one for the
1206    /// body is used to determine the type of the `Let` relation.
1207    ///
1208    /// It is meant to be used during post-order traversals to compute arities
1209    /// incrementally.
1210    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1211    where
1212        I: Iterator<Item = usize>,
1213    {
1214        use MirRelationExpr::*;
1215
1216        match self {
1217            Constant { rows: _, typ } => typ.arity(),
1218            Get { typ, .. } => typ.arity(),
1219            Let { .. } => {
1220                input_arities.next();
1221                input_arities.next().unwrap()
1222            }
1223            LetRec { values, .. } => {
1224                for _ in 0..values.len() {
1225                    input_arities.next();
1226                }
1227                input_arities.next().unwrap()
1228            }
1229            Project { outputs, .. } => outputs.len(),
1230            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1231            FlatMap { func, .. } => {
1232                input_arities.next().unwrap() + func.output_type().column_types.len()
1233            }
1234            Join { .. } => input_arities.sum(),
1235            Reduce {
1236                input: _,
1237                group_key,
1238                aggregates,
1239                ..
1240            } => group_key.len() + aggregates.len(),
1241            Filter { .. }
1242            | TopK { .. }
1243            | Negate { .. }
1244            | Threshold { .. }
1245            | Union { .. }
1246            | ArrangeBy { .. } => input_arities.next().unwrap(),
1247        }
1248    }
1249
1250    /// The number of child relations this relation has.
1251    pub fn num_inputs(&self) -> usize {
1252        let mut count = 0;
1253
1254        self.visit_children(|_| count += 1);
1255
1256        count
1257    }
1258
1259    /// Constructs a constant collection from specific rows and schema, where
1260    /// each row will have a multiplicity of one.
1261    pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
1262        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1263        MirRelationExpr::constant_diff(rows, typ)
1264    }
1265
1266    /// Constructs a constant collection from specific rows and schema, where
1267    /// each row can have an arbitrary multiplicity.
1268    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: RelationType) -> Self {
1269        for (row, _diff) in &rows {
1270            for (datum, column_typ) in row.iter().zip(typ.column_types.iter()) {
1271                assert!(
1272                    datum.is_instance_of(column_typ),
1273                    "Expected datum of type {:?}, got value {:?}",
1274                    column_typ,
1275                    datum
1276                );
1277            }
1278        }
1279        let rows = Ok(rows
1280            .into_iter()
1281            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1282            .collect());
1283        MirRelationExpr::Constant { rows, typ }
1284    }
1285
1286    /// If self is a constant, return the value and the type, otherwise `None`.
1287    /// Looks behind `ArrangeBy`s.
1288    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &RelationType)> {
1289        match self {
1290            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1291            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1292            _ => None,
1293        }
1294    }
1295
1296    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1297    /// Looks behind `ArrangeBy`s.
1298    pub fn as_const_mut(
1299        &mut self,
1300    ) -> Option<(&mut Result<Vec<(Row, Diff)>, EvalError>, &mut RelationType)> {
1301        match self {
1302            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1303            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1304            _ => None,
1305        }
1306    }
1307
1308    /// If self is a constant error, return the error, otherwise `None`.
1309    /// Looks behind `ArrangeBy`s.
1310    pub fn as_const_err(&self) -> Option<&EvalError> {
1311        match self {
1312            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1313            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1314            _ => None,
1315        }
1316    }
1317
1318    /// Checks if `self` is the single element collection with no columns.
1319    pub fn is_constant_singleton(&self) -> bool {
1320        if let Some((Ok(rows), typ)) = self.as_const() {
1321            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1322        } else {
1323            false
1324        }
1325    }
1326
1327    /// Constructs the expression for getting a local collection.
1328    pub fn local_get(id: LocalId, typ: RelationType) -> Self {
1329        MirRelationExpr::Get {
1330            id: Id::Local(id),
1331            typ,
1332            access_strategy: AccessStrategy::UnknownOrLocal,
1333        }
1334    }
1335
1336    /// Constructs the expression for getting a global collection
1337    pub fn global_get(id: GlobalId, typ: RelationType) -> Self {
1338        MirRelationExpr::Get {
1339            id: Id::Global(id),
1340            typ,
1341            access_strategy: AccessStrategy::UnknownOrLocal,
1342        }
1343    }
1344
1345    /// Retains only the columns specified by `output`.
1346    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1347        if let MirRelationExpr::Project {
1348            outputs: columns, ..
1349        } = &mut self
1350        {
1351            // Update `outputs` to reference base columns of `input`.
1352            for column in outputs.iter_mut() {
1353                *column = columns[*column];
1354            }
1355            *columns = outputs;
1356            self
1357        } else {
1358            MirRelationExpr::Project {
1359                input: Box::new(self),
1360                outputs,
1361            }
1362        }
1363    }
1364
1365    /// Append to each row the results of applying elements of `scalar`.
1366    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1367        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1368            s.extend(scalars);
1369            self
1370        } else if !scalars.is_empty() {
1371            MirRelationExpr::Map {
1372                input: Box::new(self),
1373                scalars,
1374            }
1375        } else {
1376            self
1377        }
1378    }
1379
1380    /// Append to each row a single `scalar`.
1381    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1382        self.map(vec![scalar])
1383    }
1384
1385    /// Like `map`, but yields zero-or-more output rows per input row
1386    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1387        MirRelationExpr::FlatMap {
1388            input: Box::new(self),
1389            func,
1390            exprs,
1391        }
1392    }
1393
1394    /// Retain only the rows satisfying each of several predicates.
1395    pub fn filter<I>(mut self, predicates: I) -> Self
1396    where
1397        I: IntoIterator<Item = MirScalarExpr>,
1398    {
1399        // Extract existing predicates
1400        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1401            self = *input;
1402            predicates
1403        } else {
1404            Vec::new()
1405        };
1406        // Normalize collection of predicates.
1407        new_predicates.extend(predicates);
1408        new_predicates.retain(|p| !p.is_literal_true());
1409        new_predicates.sort();
1410        new_predicates.dedup();
1411        // Introduce a `Filter` only if we have predicates.
1412        if !new_predicates.is_empty() {
1413            self = MirRelationExpr::Filter {
1414                input: Box::new(self),
1415                predicates: new_predicates,
1416            };
1417        }
1418
1419        self
1420    }
1421
1422    /// Form the Cartesian outer-product of rows in both inputs.
1423    pub fn product(mut self, right: Self) -> Self {
1424        if right.is_constant_singleton() {
1425            self
1426        } else if self.is_constant_singleton() {
1427            right
1428        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1429            inputs.push(right);
1430            self
1431        } else {
1432            MirRelationExpr::join(vec![self, right], vec![])
1433        }
1434    }
1435
1436    /// Performs a relational equijoin among the input collections.
1437    ///
1438    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1439    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1440    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1441    ///
1442    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1443    /// input collection and for each row examining its `column`th column.
1444    ///
1445    /// # Example
1446    ///
1447    /// ```rust
1448    /// use mz_repr::{Datum, ColumnType, RelationType, ScalarType};
1449    /// use mz_expr::MirRelationExpr;
1450    ///
1451    /// // A common schema for each input.
1452    /// let schema = RelationType::new(vec![
1453    ///     ScalarType::Int32.nullable(false),
1454    ///     ScalarType::Int32.nullable(false),
1455    /// ]);
1456    ///
1457    /// // the specific data are not important here.
1458    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1459    ///
1460    /// // Three collections that could have been different.
1461    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1462    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1463    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1464    ///
1465    /// // Join the three relations looking for triangles, like so.
1466    /// //
1467    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1468    /// let joined = MirRelationExpr::join(
1469    ///     vec![input0, input1, input2],
1470    ///     vec![
1471    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1472    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1473    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1474    ///     ],
1475    /// );
1476    ///
1477    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1478    /// // A projection resolves this and produces the correct output.
1479    /// let result = joined.project(vec![0, 1, 3]);
1480    /// ```
1481    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1482        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1483
1484        let equivalences = variables
1485            .into_iter()
1486            .map(|vs| {
1487                vs.into_iter()
1488                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::Column(c), r))
1489                    .collect::<Vec<_>>()
1490            })
1491            .collect::<Vec<_>>();
1492
1493        Self::join_scalars(inputs, equivalences)
1494    }
1495
1496    /// Constructs a join operator from inputs and required-equal scalar expressions.
1497    pub fn join_scalars(
1498        mut inputs: Vec<MirRelationExpr>,
1499        equivalences: Vec<Vec<MirScalarExpr>>,
1500    ) -> Self {
1501        // Remove all constant inputs that are the identity for join.
1502        // They neither introduce nor modify any column references.
1503        inputs.retain(|i| !i.is_constant_singleton());
1504        MirRelationExpr::Join {
1505            inputs,
1506            equivalences,
1507            implementation: JoinImplementation::Unimplemented,
1508        }
1509    }
1510
1511    /// Perform a key-wise reduction / aggregation.
1512    ///
1513    /// The `group_key` argument indicates columns in the input collection that should
1514    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1515    /// one output column in addition to the keys.
1516    pub fn reduce(
1517        self,
1518        group_key: Vec<usize>,
1519        aggregates: Vec<AggregateExpr>,
1520        expected_group_size: Option<u64>,
1521    ) -> Self {
1522        MirRelationExpr::Reduce {
1523            input: Box::new(self),
1524            group_key: group_key.into_iter().map(MirScalarExpr::Column).collect(),
1525            aggregates,
1526            monotonic: false,
1527            expected_group_size,
1528        }
1529    }
1530
1531    /// Perform a key-wise reduction order by and limit.
1532    ///
1533    /// The `group_key` argument indicates columns in the input collection that should
1534    /// be grouped, the `order_key` argument indicates columns that should be further
1535    /// used to order records within groups, and the `limit` argument constrains the
1536    /// total number of records that should be produced in each group.
1537    pub fn top_k(
1538        self,
1539        group_key: Vec<usize>,
1540        order_key: Vec<ColumnOrder>,
1541        limit: Option<MirScalarExpr>,
1542        offset: usize,
1543        expected_group_size: Option<u64>,
1544    ) -> Self {
1545        MirRelationExpr::TopK {
1546            input: Box::new(self),
1547            group_key,
1548            order_key,
1549            limit,
1550            offset,
1551            expected_group_size,
1552            monotonic: false,
1553        }
1554    }
1555
1556    /// Negates the occurrences of each row.
1557    pub fn negate(self) -> Self {
1558        if let MirRelationExpr::Negate { input } = self {
1559            *input
1560        } else {
1561            MirRelationExpr::Negate {
1562                input: Box::new(self),
1563            }
1564        }
1565    }
1566
1567    /// Removes all but the first occurrence of each row.
1568    pub fn distinct(self) -> Self {
1569        let arity = self.arity();
1570        self.distinct_by((0..arity).collect())
1571    }
1572
1573    /// Removes all but the first occurrence of each key. Columns not included
1574    /// in the `group_key` are discarded.
1575    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1576        self.reduce(group_key, vec![], None)
1577    }
1578
1579    /// Discards rows with a negative frequency.
1580    pub fn threshold(self) -> Self {
1581        if let MirRelationExpr::Threshold { .. } = &self {
1582            self
1583        } else {
1584            MirRelationExpr::Threshold {
1585                input: Box::new(self),
1586            }
1587        }
1588    }
1589
1590    /// Unions together any number inputs.
1591    ///
1592    /// If `inputs` is empty, then an empty relation of type `typ` is
1593    /// constructed.
1594    pub fn union_many(mut inputs: Vec<Self>, typ: RelationType) -> Self {
1595        // Deconstruct `inputs` as `Union`s and reconstitute.
1596        let mut flat_inputs = Vec::with_capacity(inputs.len());
1597        for input in inputs {
1598            if let MirRelationExpr::Union { base, inputs } = input {
1599                flat_inputs.push(*base);
1600                flat_inputs.extend(inputs);
1601            } else {
1602                flat_inputs.push(input);
1603            }
1604        }
1605        inputs = flat_inputs;
1606        if inputs.len() == 0 {
1607            MirRelationExpr::Constant {
1608                rows: Ok(vec![]),
1609                typ,
1610            }
1611        } else if inputs.len() == 1 {
1612            inputs.into_element()
1613        } else {
1614            MirRelationExpr::Union {
1615                base: Box::new(inputs.remove(0)),
1616                inputs,
1617            }
1618        }
1619    }
1620
1621    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1622    pub fn union(self, other: Self) -> Self {
1623        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1624        let mut flat_inputs = Vec::with_capacity(2);
1625        if let MirRelationExpr::Union { base, inputs } = self {
1626            flat_inputs.push(*base);
1627            flat_inputs.extend(inputs);
1628        } else {
1629            flat_inputs.push(self);
1630        }
1631        if let MirRelationExpr::Union { base, inputs } = other {
1632            flat_inputs.push(*base);
1633            flat_inputs.extend(inputs);
1634        } else {
1635            flat_inputs.push(other);
1636        }
1637
1638        MirRelationExpr::Union {
1639            base: Box::new(flat_inputs.remove(0)),
1640            inputs: flat_inputs,
1641        }
1642    }
1643
1644    /// Arranges the collection by the specified columns
1645    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1646        MirRelationExpr::ArrangeBy {
1647            input: Box::new(self),
1648            keys: keys.to_owned(),
1649        }
1650    }
1651
1652    /// Indicates if this is a constant empty collection.
1653    ///
1654    /// A false value does not mean the collection is known to be non-empty,
1655    /// only that we cannot currently determine that it is statically empty.
1656    pub fn is_empty(&self) -> bool {
1657        if let Some((Ok(rows), ..)) = self.as_const() {
1658            rows.is_empty()
1659        } else {
1660            false
1661        }
1662    }
1663
1664    /// If the expression is a negated project, return the input and the projection.
1665    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1666        if let MirRelationExpr::Negate { input } = self {
1667            if let MirRelationExpr::Project { input, outputs } = &**input {
1668                return Some((&**input, outputs));
1669            }
1670        }
1671        if let MirRelationExpr::Project { input, outputs } = self {
1672            if let MirRelationExpr::Negate { input } = &**input {
1673                return Some((&**input, outputs));
1674            }
1675        }
1676        None
1677    }
1678
1679    /// Pretty-print this [MirRelationExpr] to a string.
1680    pub fn pretty(&self) -> String {
1681        let config = ExplainConfig::default();
1682        self.explain(&config, None)
1683    }
1684
1685    /// Pretty-print this [MirRelationExpr] to a string using a custom
1686    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1687    pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1688        text_string_at(self, || PlanRenderingContext {
1689            indent: Indent::default(),
1690            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1691            annotations: BTreeMap::default(),
1692            config,
1693        })
1694    }
1695
1696    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1697    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1698    /// would find. Keys and nullability are ignored in the given `RelationType`, and instead we set
1699    /// the best possible key and nullability, since we are making an empty collection.
1700    ///
1701    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1702    /// the correct type.
1703    pub fn take_safely(&mut self, typ: Option<RelationType>) -> MirRelationExpr {
1704        if let Some(typ) = &typ {
1705            soft_assert_no_log!(
1706                self.typ()
1707                    .column_types
1708                    .iter()
1709                    .zip_eq(typ.column_types.iter())
1710                    .all(|(t1, t2)| t1.scalar_type.base_eq(&t2.scalar_type))
1711            );
1712        }
1713        let mut typ = typ.unwrap_or_else(|| self.typ());
1714        typ.keys = vec![vec![]];
1715        for ct in typ.column_types.iter_mut() {
1716            ct.nullable = false;
1717        }
1718        std::mem::replace(
1719            self,
1720            MirRelationExpr::Constant {
1721                rows: Ok(vec![]),
1722                typ,
1723            },
1724        )
1725    }
1726
1727    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1728    /// types. Nullability is ignored in the given `ColumnType`s, and instead we set the best
1729    /// possible nullability, since we are making an empty collection.
1730    pub fn take_safely_with_col_types(&mut self, typ: Vec<ColumnType>) -> MirRelationExpr {
1731        self.take_safely(Some(RelationType::new(typ)))
1732    }
1733
1734    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1735    ///
1736    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1737    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1738        let empty = MirRelationExpr::Constant {
1739            rows: Ok(vec![]),
1740            typ: RelationType::new(Vec::new()),
1741        };
1742        std::mem::replace(self, empty)
1743    }
1744
1745    /// Replaces `self` with some logic applied to `self`.
1746    pub fn replace_using<F>(&mut self, logic: F)
1747    where
1748        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1749    {
1750        let empty = MirRelationExpr::Constant {
1751            rows: Ok(vec![]),
1752            typ: RelationType::new(Vec::new()),
1753        };
1754        let expr = std::mem::replace(self, empty);
1755        *self = logic(expr);
1756    }
1757
1758    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1759    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1760    where
1761        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1762    {
1763        if let MirRelationExpr::Get { .. } = self {
1764            // already done
1765            body(id_gen, self)
1766        } else {
1767            let id = LocalId::new(id_gen.allocate_id());
1768            let get = MirRelationExpr::Get {
1769                id: Id::Local(id),
1770                typ: self.typ(),
1771                access_strategy: AccessStrategy::UnknownOrLocal,
1772            };
1773            let body = (body)(id_gen, get)?;
1774            Ok(MirRelationExpr::Let {
1775                id,
1776                value: Box::new(self),
1777                body: Box::new(body),
1778            })
1779        }
1780    }
1781
1782    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1783    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1784    pub fn anti_lookup<E>(
1785        self,
1786        id_gen: &mut IdGen,
1787        keys_and_values: MirRelationExpr,
1788        default: Vec<(Datum, ScalarType)>,
1789    ) -> Result<MirRelationExpr, E> {
1790        let (data, column_types): (Vec<_>, Vec<_>) = default
1791            .into_iter()
1792            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
1793            .unzip();
1794        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1795        self.let_in(id_gen, |_id_gen, get_keys| {
1796            let get_keys_arity = get_keys.arity();
1797            Ok(MirRelationExpr::join(
1798                vec![
1799                    // all the missing keys (with count 1)
1800                    keys_and_values
1801                        .distinct_by((0..get_keys_arity).collect())
1802                        .negate()
1803                        .union(get_keys.clone().distinct()),
1804                    // join with keys to get the correct counts
1805                    get_keys.clone(),
1806                ],
1807                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1808            )
1809            // get rid of the extra copies of columns from keys
1810            .project((0..get_keys_arity).collect())
1811            // This join is logically equivalent to
1812            // `.map(<default_expr>)`, but using a join allows for
1813            // potential predicate pushdown and elision in the
1814            // optimizer.
1815            .product(MirRelationExpr::constant(
1816                vec![data],
1817                RelationType::new(column_types),
1818            )))
1819        })
1820    }
1821
1822    /// Return:
1823    /// * every row in keys_and_values
1824    /// * every row in `self` that does not have a matching row in the first columns of
1825    ///   `keys_and_values`, using `default` to fill in the remaining columns
1826    /// (This is LEFT OUTER JOIN if:
1827    /// 1) `default` is a row of null
1828    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1829    pub fn lookup<E>(
1830        self,
1831        id_gen: &mut IdGen,
1832        keys_and_values: MirRelationExpr,
1833        default: Vec<(Datum<'static>, ScalarType)>,
1834    ) -> Result<MirRelationExpr, E> {
1835        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1836            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1837                id_gen,
1838                get_keys_and_values,
1839                default,
1840            )?))
1841        })
1842    }
1843
1844    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1845    pub fn contains_temporal(&self) -> bool {
1846        let mut contains = false;
1847        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1848        contains
1849    }
1850
1851    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1852    ///
1853    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1854    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1855    where
1856        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1857    {
1858        use MirRelationExpr::*;
1859        match self {
1860            Map { scalars, .. } => {
1861                for s in scalars {
1862                    f(s)?;
1863                }
1864            }
1865            Filter { predicates, .. } => {
1866                for p in predicates {
1867                    f(p)?;
1868                }
1869            }
1870            FlatMap { exprs, .. } => {
1871                for expr in exprs {
1872                    f(expr)?;
1873                }
1874            }
1875            Join {
1876                inputs: _,
1877                equivalences,
1878                implementation,
1879            } => {
1880                for equivalence in equivalences {
1881                    for expr in equivalence {
1882                        f(expr)?;
1883                    }
1884                }
1885                match implementation {
1886                    JoinImplementation::Differential((_, start_key, _), order) => {
1887                        if let Some(start_key) = start_key {
1888                            for k in start_key {
1889                                f(k)?;
1890                            }
1891                        }
1892                        for (_, lookup_key, _) in order {
1893                            for k in lookup_key {
1894                                f(k)?;
1895                            }
1896                        }
1897                    }
1898                    JoinImplementation::DeltaQuery(paths) => {
1899                        for path in paths {
1900                            for (_, lookup_key, _) in path {
1901                                for k in lookup_key {
1902                                    f(k)?;
1903                                }
1904                            }
1905                        }
1906                    }
1907                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1908                        for k in index_key {
1909                            f(k)?;
1910                        }
1911                    }
1912                    JoinImplementation::Unimplemented => {} // No scalar exprs
1913                }
1914            }
1915            ArrangeBy { keys, .. } => {
1916                for key in keys {
1917                    for s in key {
1918                        f(s)?;
1919                    }
1920                }
1921            }
1922            Reduce {
1923                group_key,
1924                aggregates,
1925                ..
1926            } => {
1927                for s in group_key {
1928                    f(s)?;
1929                }
1930                for agg in aggregates {
1931                    f(&mut agg.expr)?;
1932                }
1933            }
1934            TopK { limit, .. } => {
1935                if let Some(s) = limit {
1936                    f(s)?;
1937                }
1938            }
1939            Constant { .. }
1940            | Get { .. }
1941            | Let { .. }
1942            | LetRec { .. }
1943            | Project { .. }
1944            | Negate { .. }
1945            | Threshold { .. }
1946            | Union { .. } => (),
1947        }
1948        Ok(())
1949    }
1950
1951    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1952    /// rooted at `self`.
1953    ///
1954    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1955    /// nodes.
1956    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1957    where
1958        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1959        E: From<RecursionLimitError>,
1960    {
1961        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1962    }
1963
1964    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1965    /// rooted at `self`.
1966    ///
1967    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1968    /// nodes.
1969    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1970    where
1971        F: FnMut(&mut MirScalarExpr),
1972    {
1973        self.try_visit_scalars_mut(&mut |s| {
1974            f(s);
1975            Ok::<_, RecursionLimitError>(())
1976        })
1977        .expect("Unexpected error in `visit_scalars_mut` call");
1978    }
1979
1980    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1981    ///
1982    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1983    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1984    where
1985        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1986    {
1987        use MirRelationExpr::*;
1988        match self {
1989            Map { scalars, .. } => {
1990                for s in scalars {
1991                    f(s)?;
1992                }
1993            }
1994            Filter { predicates, .. } => {
1995                for p in predicates {
1996                    f(p)?;
1997                }
1998            }
1999            FlatMap { exprs, .. } => {
2000                for expr in exprs {
2001                    f(expr)?;
2002                }
2003            }
2004            Join {
2005                inputs: _,
2006                equivalences,
2007                implementation,
2008            } => {
2009                for equivalence in equivalences {
2010                    for expr in equivalence {
2011                        f(expr)?;
2012                    }
2013                }
2014                match implementation {
2015                    JoinImplementation::Differential((_, start_key, _), order) => {
2016                        if let Some(start_key) = start_key {
2017                            for k in start_key {
2018                                f(k)?;
2019                            }
2020                        }
2021                        for (_, lookup_key, _) in order {
2022                            for k in lookup_key {
2023                                f(k)?;
2024                            }
2025                        }
2026                    }
2027                    JoinImplementation::DeltaQuery(paths) => {
2028                        for path in paths {
2029                            for (_, lookup_key, _) in path {
2030                                for k in lookup_key {
2031                                    f(k)?;
2032                                }
2033                            }
2034                        }
2035                    }
2036                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
2037                        for k in index_key {
2038                            f(k)?;
2039                        }
2040                    }
2041                    JoinImplementation::Unimplemented => {} // No scalar exprs
2042                }
2043            }
2044            ArrangeBy { keys, .. } => {
2045                for key in keys {
2046                    for s in key {
2047                        f(s)?;
2048                    }
2049                }
2050            }
2051            Reduce {
2052                group_key,
2053                aggregates,
2054                ..
2055            } => {
2056                for s in group_key {
2057                    f(s)?;
2058                }
2059                for agg in aggregates {
2060                    f(&agg.expr)?;
2061                }
2062            }
2063            TopK { limit, .. } => {
2064                if let Some(s) = limit {
2065                    f(s)?;
2066                }
2067            }
2068            Constant { .. }
2069            | Get { .. }
2070            | Let { .. }
2071            | LetRec { .. }
2072            | Project { .. }
2073            | Negate { .. }
2074            | Threshold { .. }
2075            | Union { .. } => (),
2076        }
2077        Ok(())
2078    }
2079
2080    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2081    /// rooted at `self`.
2082    ///
2083    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2084    /// nodes.
2085    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
2086    where
2087        F: FnMut(&MirScalarExpr) -> Result<(), E>,
2088        E: From<RecursionLimitError>,
2089    {
2090        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
2091    }
2092
2093    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2094    /// rooted at `self`.
2095    ///
2096    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2097    /// nodes.
2098    pub fn visit_scalars<F>(&self, f: &mut F)
2099    where
2100        F: FnMut(&MirScalarExpr),
2101    {
2102        self.try_visit_scalars(&mut |s| {
2103            f(s);
2104            Ok::<_, RecursionLimitError>(())
2105        })
2106        .expect("Unexpected error in `visit_scalars` call");
2107    }
2108
2109    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2110    /// stack overflow in `drop_in_place`.
2111    ///
2112    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2113    /// dropped or otherwise overwritten.
2114    pub fn destroy_carefully(&mut self) {
2115        let mut todo = vec![self.take_dangerous()];
2116        while let Some(mut expr) = todo.pop() {
2117            for child in expr.children_mut() {
2118                todo.push(child.take_dangerous());
2119            }
2120        }
2121    }
2122
2123    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2124    /// debug printing purposes.
2125    pub fn debug_size_and_depth(&self) -> (usize, usize) {
2126        let mut size = 0;
2127        let mut max_depth = 0;
2128        let mut todo = vec![(self, 1)];
2129        while let Some((expr, depth)) = todo.pop() {
2130            size += 1;
2131            max_depth = max(max_depth, depth);
2132            todo.extend(expr.children().map(|c| (c, depth + 1)));
2133        }
2134        (size, max_depth)
2135    }
2136
2137    /// The MirRelationExpr is considered potentially expensive if and only if
2138    /// at least one of the following conditions is true:
2139    ///
2140    ///  - It contains at least one FlatMap or a Reduce operator.
2141    ///  - It contains at least one MirScalarExpr with a function call.
2142    ///
2143    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2144    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2145    pub fn could_run_expensive_function(&self) -> bool {
2146        let mut result = false;
2147        self.visit_pre(|e: &MirRelationExpr| {
2148            use MirRelationExpr::*;
2149            use MirScalarExpr::*;
2150            if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2151                result |= match scalar {
2152                    Column(_) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2153                    // Function calls are considered expensive
2154                    CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2155                };
2156                Ok(())
2157            }) {
2158                // Conservatively set `true` if on RecursionLimitError.
2159                result = true;
2160            }
2161            // FlatMap has a table function; Reduce has an aggregate function.
2162            // Other constructs use MirScalarExpr to run a function
2163            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2164        });
2165        result
2166    }
2167
2168    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2169    /// than what `Hashable::hashed` would give us.)
2170    pub fn hash_to_u64(&self) -> u64 {
2171        let mut h = DefaultHasher::new();
2172        self.hash(&mut h);
2173        h.finish()
2174    }
2175}
2176
2177// `LetRec` helpers
2178impl MirRelationExpr {
2179    /// True when `expr` contains a `LetRec` AST node.
2180    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2181        let mut worklist = vec![self];
2182        while let Some(expr) = worklist.pop() {
2183            if let MirRelationExpr::LetRec { .. } = expr {
2184                return true;
2185            }
2186            worklist.extend(expr.children());
2187        }
2188        false
2189    }
2190
2191    /// Return the number of sub-expressions in the tree (including self).
2192    pub fn size(&self) -> usize {
2193        let mut size = 0;
2194        self.visit_pre(|_| size += 1);
2195        size
2196    }
2197
2198    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2199    /// iterations. These are those ids that have a reference before they are defined, when reading
2200    /// all the bindings in order.
2201    ///
2202    /// For example:
2203    /// ```SQL
2204    /// WITH MUTUALLY RECURSIVE
2205    ///     x(...) AS f(z),
2206    ///     y(...) AS g(x),
2207    ///     z(...) AS h(y)
2208    /// ...;
2209    /// ```
2210    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2211    /// iteration.
2212    ///
2213    /// Note that if a binding references itself, that is also returned.
2214    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2215        let mut used_across_iterations = BTreeSet::new();
2216        let mut defined = BTreeSet::new();
2217        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2218            value.visit_pre(|expr| {
2219                if let MirRelationExpr::Get {
2220                    id: Local(get_id), ..
2221                } = expr
2222                {
2223                    // If we haven't seen a definition for it yet, then this will refer
2224                    // to the previous iteration.
2225                    // The `ids.contains` part of the condition is needed to exclude
2226                    // those ids that are not really in this LetRec, but either an inner
2227                    // or outer one.
2228                    if !defined.contains(get_id) && ids.contains(get_id) {
2229                        used_across_iterations.insert(*get_id);
2230                    }
2231                }
2232            });
2233            defined.insert(*binding_id);
2234        }
2235        used_across_iterations
2236    }
2237
    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
    ///
    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
    /// identifiers are rewritten to be the constant collection.
    /// This makes the computation perform exactly "one" iteration.
    ///
    /// This was used only temporarily while developing `LetRec`.
    ///
    /// The rewrite happens in place. The `limits` of each `LetRec` are ignored.
    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
        // Ids of all `LetRec` bindings processed so far. The set is shared across every
        // `LetRec` encountered during the traversal.
        let mut deadlist = BTreeSet::new();
        // Explicit stack of nodes that still have to be examined.
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            if let MirRelationExpr::LetRec {
                ids,
                values,
                limits: _,
                body,
            } = expr
            {
                // Move the values out of the `LetRec`, pairing each one with its id.
                let ids_values = values
                    .drain(..)
                    .zip(ids)
                    .map(|(value, id)| (*id, value))
                    .collect::<Vec<_>>();
                // Replace the `LetRec` node by its body; the `Let` stack is wrapped around
                // it below.
                *expr = body.take_dangerous();
                // Walk the bindings last-to-first, so that the first binding ends up as
                // the outermost `Let`.
                for (id, mut value) in ids_values.into_iter().rev() {
                    // Remove references to potentially recursive identifiers.
                    // The id is recorded before `value` is scanned, so a binding's
                    // references to itself are replaced as well.
                    deadlist.insert(id);
                    value.visit_pre_mut(|e| {
                        if let MirRelationExpr::Get {
                            id: crate::Id::Local(id),
                            typ,
                            ..
                        } = e
                        {
                            // Clone the type so that it no longer borrows from `e`, which
                            // is needed mutably for `take_safely`.
                            let typ = typ.clone();
                            if deadlist.contains(id) {
                                // NOTE(review): presumably swaps `e` for an empty
                                // collection of type `typ` -- confirm against
                                // `take_safely`.
                                e.take_safely(Some(typ));
                            }
                        }
                    });
                    *expr = MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(expr.take_dangerous()),
                    };
                }
                // Revisit the rewritten node: it is now a `Let` (or the bare body when
                // there were no bindings) and may still contain further `LetRec`s.
                worklist.push(expr);
            } else {
                worklist.extend(expr.children_mut().rev());
            }
        }
    }
2290
2291    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2292    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2293    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2294    /// redefinition.
2295    ///
2296    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2297    pub fn collect_expirations(
2298        id: LocalId,
2299        expr: &MirRelationExpr,
2300        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2301    ) {
2302        expr.visit_pre(|e| {
2303            if let MirRelationExpr::Get {
2304                id: Id::Local(referenced_id),
2305                ..
2306            } = e
2307            {
2308                // The following check needs `renumber_bindings` to have run recently
2309                if referenced_id >= &id {
2310                    expire_whens
2311                        .entry(*referenced_id)
2312                        .or_insert_with(Vec::new)
2313                        .push(id);
2314                }
2315            }
2316        });
2317    }
2318
2319    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2320    /// about such Ids whose information depended on the earlier definition of `id`, according to
2321    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2322    pub fn do_expirations<I>(
2323        redefined_id: LocalId,
2324        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2325        id_infos: &mut BTreeMap<LocalId, I>,
2326    ) -> Vec<(LocalId, I)> {
2327        let mut expired_infos = Vec::new();
2328        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2329            for expired_id in expirations.into_iter() {
2330                if let Some(offer) = id_infos.remove(&expired_id) {
2331                    expired_infos.push((expired_id, offer));
2332                }
2333            }
2334        }
2335        expired_infos
2336    }
2337}
2338/// Augment non-nullability of columns, by observing either
2339/// 1. Predicates that explicitly test for null values, and
2340/// 2. Columns that if null would make a predicate be null.
2341pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2342    let mut nonnull_required_columns = BTreeSet::new();
2343    for predicate in predicates {
2344        // Add any columns that being null would force the predicate to be null.
2345        // Should that happen, the row would be discarded.
2346        predicate.non_null_requirements(&mut nonnull_required_columns);
2347
2348        /*
2349        Test for explicit checks that a column is non-null.
2350
2351        This analysis is ad hoc, and will miss things:
2352
2353        materialize=> create table a(x int, y int);
2354        CREATE TABLE
2355        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2356        Optimized Plan
2357        --------------------------------------------------------------------------------------------------------
2358        Explained Query:                                                                                      +
2359        Project (#0) // { types: "(integer?)" }                                                             +
2360        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2361        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2362                                                                                  +
2363        Source materialize.public.a                                                                           +
2364        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2365
2366        (1 row)
2367        */
2368
2369        if let MirScalarExpr::CallUnary {
2370            func: UnaryFunc::Not(scalar_func::Not),
2371            expr,
2372        } = predicate
2373        {
2374            if let MirScalarExpr::CallUnary {
2375                func: UnaryFunc::IsNull(scalar_func::IsNull),
2376                expr,
2377            } = &**expr
2378            {
2379                if let MirScalarExpr::Column(c) = &**expr {
2380                    nonnull_required_columns.insert(*c);
2381                }
2382            }
2383        }
2384    }
2385
2386    nonnull_required_columns
2387}
2388
2389impl CollectionPlan for MirRelationExpr {
2390    // !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2391    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2392    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2393        if let MirRelationExpr::Get {
2394            id: Id::Global(id), ..
2395        } = self
2396        {
2397            out.insert(*id);
2398        }
2399        self.visit_children(|expr| expr.depends_on_into(out))
2400    }
2401}
2402
2403impl MirRelationExpr {
2404    /// Iterates through references to child expressions.
2405    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2406        let mut first = None;
2407        let mut second = None;
2408        let mut rest = None;
2409        let mut last = None;
2410
2411        use MirRelationExpr::*;
2412        match self {
2413            Constant { .. } | Get { .. } => (),
2414            Let { value, body, .. } => {
2415                first = Some(&**value);
2416                second = Some(&**body);
2417            }
2418            LetRec { values, body, .. } => {
2419                rest = Some(values);
2420                last = Some(&**body);
2421            }
2422            Project { input, .. }
2423            | Map { input, .. }
2424            | FlatMap { input, .. }
2425            | Filter { input, .. }
2426            | Reduce { input, .. }
2427            | TopK { input, .. }
2428            | Negate { input }
2429            | Threshold { input }
2430            | ArrangeBy { input, .. } => {
2431                first = Some(&**input);
2432            }
2433            Join { inputs, .. } => {
2434                rest = Some(inputs);
2435            }
2436            Union { base, inputs } => {
2437                first = Some(&**base);
2438                rest = Some(inputs);
2439            }
2440        }
2441
2442        first
2443            .into_iter()
2444            .chain(second)
2445            .chain(rest.into_iter().flatten())
2446            .chain(last)
2447    }
2448
2449    /// Iterates through mutable references to child expressions.
2450    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2451        let mut first = None;
2452        let mut second = None;
2453        let mut rest = None;
2454        let mut last = None;
2455
2456        use MirRelationExpr::*;
2457        match self {
2458            Constant { .. } | Get { .. } => (),
2459            Let { value, body, .. } => {
2460                first = Some(&mut **value);
2461                second = Some(&mut **body);
2462            }
2463            LetRec { values, body, .. } => {
2464                rest = Some(values);
2465                last = Some(&mut **body);
2466            }
2467            Project { input, .. }
2468            | Map { input, .. }
2469            | FlatMap { input, .. }
2470            | Filter { input, .. }
2471            | Reduce { input, .. }
2472            | TopK { input, .. }
2473            | Negate { input }
2474            | Threshold { input }
2475            | ArrangeBy { input, .. } => {
2476                first = Some(&mut **input);
2477            }
2478            Join { inputs, .. } => {
2479                rest = Some(inputs);
2480            }
2481            Union { base, inputs } => {
2482                first = Some(&mut **base);
2483                rest = Some(inputs);
2484            }
2485        }
2486
2487        first
2488            .into_iter()
2489            .chain(second)
2490            .chain(rest.into_iter().flatten())
2491            .chain(last)
2492    }
2493
2494    /// Iterative pre-order visitor.
2495    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2496        let mut worklist = vec![self];
2497        while let Some(expr) = worklist.pop() {
2498            f(expr);
2499            worklist.extend(expr.children().rev());
2500        }
2501    }
2502
2503    /// Iterative pre-order visitor.
2504    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2505        let mut worklist = vec![self];
2506        while let Some(expr) = worklist.pop() {
2507            f(expr);
2508            worklist.extend(expr.children_mut().rev());
2509        }
2510    }
2511
2512    /// Return a vector of references to the subtrees of this expression
2513    /// in post-visit order (the last element is `&self`).
2514    pub fn post_order_vec(&self) -> Vec<&Self> {
2515        let mut stack = vec![self];
2516        let mut result = vec![];
2517        while let Some(expr) = stack.pop() {
2518            result.push(expr);
2519            stack.extend(expr.children());
2520        }
2521        result.reverse();
2522        result
2523    }
2524}
2525
2526impl VisitChildren<Self> for MirRelationExpr {
2527    fn visit_children<F>(&self, mut f: F)
2528    where
2529        F: FnMut(&Self),
2530    {
2531        for child in self.children() {
2532            f(child)
2533        }
2534    }
2535
2536    fn visit_mut_children<F>(&mut self, mut f: F)
2537    where
2538        F: FnMut(&mut Self),
2539    {
2540        for child in self.children_mut() {
2541            f(child)
2542        }
2543    }
2544
2545    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2546    where
2547        F: FnMut(&Self) -> Result<(), E>,
2548        E: From<RecursionLimitError>,
2549    {
2550        for child in self.children() {
2551            f(child)?
2552        }
2553        Ok(())
2554    }
2555
2556    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2557    where
2558        F: FnMut(&mut Self) -> Result<(), E>,
2559        E: From<RecursionLimitError>,
2560    {
2561        for child in self.children_mut() {
2562            f(child)?
2563        }
2564        Ok(())
2565    }
2566}
2567
/// Specification for an ordering by a column.
///
/// Combines a column index with the sort direction and the placement of nulls.
#[derive(
    Arbitrary,
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect,
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    ///
    /// Marked `#[serde(default)]`: deserializes to `false` when the field is absent.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    ///
    /// Marked `#[serde(default)]`: deserializes to `false` when the field is absent.
    #[serde(default)]
    pub nulls_last: bool,
}
2593
// `ColumnOrder` derives `Copy`; its columnation region is declared as `CopyRegion`.
impl Columnation for ColumnOrder {
    /// The region type associated with `ColumnOrder`.
    type InnerRegion = CopyRegion<Self>;
}
2597
2598impl RustType<ProtoColumnOrder> for ColumnOrder {
2599    fn into_proto(&self) -> ProtoColumnOrder {
2600        ProtoColumnOrder {
2601            column: self.column.into_proto(),
2602            desc: self.desc,
2603            nulls_last: self.nulls_last,
2604        }
2605    }
2606
2607    fn from_proto(proto: ProtoColumnOrder) -> Result<Self, TryFromProtoError> {
2608        Ok(ColumnOrder {
2609            column: proto.column.into_rust()?,
2610            desc: proto.desc,
2611            nulls_last: proto.nulls_last,
2612        })
2613    }
2614}
2615
2616impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2617where
2618    M: HumanizerMode,
2619{
2620    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2621        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2622        write!(
2623            f,
2624            "{} {} {}",
2625            self.child(&self.expr.column),
2626            if self.expr.desc { "desc" } else { "asc" },
2627            if self.expr.nulls_last {
2628                "nulls_last"
2629            } else {
2630                "nulls_first"
2631            },
2632        )
2633    }
2634}
2635
/// Describes an aggregation expression.
///
/// Consists of the aggregation function, the scalar expression producing its input,
/// and a flag restricting the aggregation to distinct inputs.
#[derive(
    Arbitrary, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    ///
    /// Marked `#[serde(default)]`: deserializes to `false` when the field is absent.
    #[serde(default)]
    pub distinct: bool,
}
2649
2650impl RustType<ProtoAggregateExpr> for AggregateExpr {
2651    fn into_proto(&self) -> ProtoAggregateExpr {
2652        ProtoAggregateExpr {
2653            func: Some(self.func.into_proto()),
2654            expr: Some(self.expr.into_proto()),
2655            distinct: self.distinct,
2656        }
2657    }
2658
2659    fn from_proto(proto: ProtoAggregateExpr) -> Result<Self, TryFromProtoError> {
2660        Ok(Self {
2661            func: proto.func.into_rust_if_some("ProtoAggregateExpr::func")?,
2662            expr: proto.expr.into_rust_if_some("ProtoAggregateExpr::expr")?,
2663            distinct: proto.distinct,
2664        })
2665    }
2666}
2667
2668impl AggregateExpr {
2669    /// Computes the type of this `AggregateExpr`.
2670    pub fn typ(&self, column_types: &[ColumnType]) -> ColumnType {
2671        self.func.output_type(self.expr.typ(column_types))
2672    }
2673
2674    /// Returns whether the expression has a constant result.
2675    pub fn is_constant(&self) -> bool {
2676        match self.func {
2677            AggregateFunc::MaxInt16
2678            | AggregateFunc::MaxInt32
2679            | AggregateFunc::MaxInt64
2680            | AggregateFunc::MaxUInt16
2681            | AggregateFunc::MaxUInt32
2682            | AggregateFunc::MaxUInt64
2683            | AggregateFunc::MaxMzTimestamp
2684            | AggregateFunc::MaxFloat32
2685            | AggregateFunc::MaxFloat64
2686            | AggregateFunc::MaxBool
2687            | AggregateFunc::MaxString
2688            | AggregateFunc::MaxDate
2689            | AggregateFunc::MaxTimestamp
2690            | AggregateFunc::MaxTimestampTz
2691            | AggregateFunc::MinInt16
2692            | AggregateFunc::MinInt32
2693            | AggregateFunc::MinInt64
2694            | AggregateFunc::MinUInt16
2695            | AggregateFunc::MinUInt32
2696            | AggregateFunc::MinUInt64
2697            | AggregateFunc::MinMzTimestamp
2698            | AggregateFunc::MinFloat32
2699            | AggregateFunc::MinFloat64
2700            | AggregateFunc::MinBool
2701            | AggregateFunc::MinString
2702            | AggregateFunc::MinDate
2703            | AggregateFunc::MinTimestamp
2704            | AggregateFunc::MinTimestampTz
2705            | AggregateFunc::Any
2706            | AggregateFunc::All
2707            | AggregateFunc::Dummy => self.expr.is_literal(),
2708            AggregateFunc::Count => self.expr.is_literal_null(),
2709            _ => self.expr.is_literal_err(),
2710        }
2711    }
2712
2713    /// Returns an expression that computes `self` on a group that has exactly one row.
2714    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2715    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2716    pub fn on_unique(&self, input_type: &[ColumnType]) -> MirScalarExpr {
2717        match &self.func {
2718            // Count is one if non-null, and zero if null.
2719            AggregateFunc::Count => self
2720                .expr
2721                .clone()
2722                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2723                .if_then_else(
2724                    MirScalarExpr::literal_ok(Datum::Int64(0), ScalarType::Int64),
2725                    MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
2726                ),
2727
2728            // SumInt16 takes Int16s as input, but outputs Int64s.
2729            AggregateFunc::SumInt16 => self
2730                .expr
2731                .clone()
2732                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2733
2734            // SumInt32 takes Int32s as input, but outputs Int64s.
2735            AggregateFunc::SumInt32 => self
2736                .expr
2737                .clone()
2738                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2739
2740            // SumInt64 takes Int64s as input, but outputs numerics.
2741            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2742                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2743            )),
2744
2745            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2746            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2747                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2748            ),
2749
2750            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2751            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2752                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2753            ),
2754
2755            // SumUInt64 takes UInt64s as input, but outputs numerics.
2756            AggregateFunc::SumUInt64 => {
2757                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2758                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2759                ))
2760            }
2761
2762            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2763            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2764                func: VariadicFunc::JsonbBuildArray,
2765                exprs: vec![
2766                    self.expr
2767                        .clone()
2768                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2769                ],
2770            },
2771
            // JsonbObjectAgg takes _anything_ as input, but must output a Jsonb object.
2773            AggregateFunc::JsonbObjectAgg { .. } => {
2774                let record = self
2775                    .expr
2776                    .clone()
2777                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2778                MirScalarExpr::CallVariadic {
2779                    func: VariadicFunc::JsonbBuildObject,
2780                    exprs: (0..2)
2781                        .map(|i| {
2782                            record
2783                                .clone()
2784                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2785                        })
2786                        .collect(),
2787                }
2788            }
2789
2790            AggregateFunc::MapAgg { value_type, .. } => {
2791                let record = self
2792                    .expr
2793                    .clone()
2794                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2795                MirScalarExpr::CallVariadic {
2796                    func: VariadicFunc::MapBuild {
2797                        value_type: value_type.clone(),
2798                    },
2799                    exprs: (0..2)
2800                        .map(|i| {
2801                            record
2802                                .clone()
2803                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2804                        })
2805                        .collect(),
2806                }
2807            }
2808
2809            // StringAgg takes nested records of strings and outputs a string
2810            AggregateFunc::StringAgg { .. } => self
2811                .expr
2812                .clone()
2813                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2814                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2815
2816            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2817            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2818                .expr
2819                .clone()
2820                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2821
2822            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2823            AggregateFunc::RowNumber { .. } => {
2824                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2825            }
2826            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2827            AggregateFunc::DenseRank { .. } => {
2828                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2829            }
2830
2831            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2832            AggregateFunc::LagLead { lag_lead, .. } => {
2833                let tuple = self
2834                    .expr
2835                    .clone()
2836                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2837
2838                // Get the overall return type
2839                let return_type_with_orig_row = self
2840                    .typ(input_type)
2841                    .scalar_type
2842                    .unwrap_list_element_type()
2843                    .clone();
2844                let lag_lead_return_type =
2845                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2846
2847                // Extract the original row
2848                let original_row = tuple
2849                    .clone()
2850                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2851
2852                // Extract the encoded args
2853                let encoded_args =
2854                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2855
2856                let (result_expr, column_name) =
2857                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2858
2859                MirScalarExpr::CallVariadic {
2860                    func: VariadicFunc::ListCreate {
2861                        elem_type: return_type_with_orig_row,
2862                    },
2863                    exprs: vec![MirScalarExpr::CallVariadic {
2864                        func: VariadicFunc::RecordCreate {
2865                            field_names: vec![column_name, ColumnName::from("?record?")],
2866                        },
2867                        exprs: vec![result_expr, original_row],
2868                    }],
2869                }
2870            }
2871
2872            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2873            AggregateFunc::FirstValue { window_frame, .. } => {
2874                let tuple = self
2875                    .expr
2876                    .clone()
2877                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2878
2879                // Get the overall return type
2880                let return_type_with_orig_row = self
2881                    .typ(input_type)
2882                    .scalar_type
2883                    .unwrap_list_element_type()
2884                    .clone();
2885                let first_value_return_type =
2886                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2887
2888                // Extract the original row
2889                let original_row = tuple
2890                    .clone()
2891                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2892
2893                // Extract the input value
2894                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2895
2896                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2897                    window_frame,
2898                    arg,
2899                    first_value_return_type,
2900                );
2901
2902                MirScalarExpr::CallVariadic {
2903                    func: VariadicFunc::ListCreate {
2904                        elem_type: return_type_with_orig_row,
2905                    },
2906                    exprs: vec![MirScalarExpr::CallVariadic {
2907                        func: VariadicFunc::RecordCreate {
2908                            field_names: vec![column_name, ColumnName::from("?record?")],
2909                        },
2910                        exprs: vec![result_expr, original_row],
2911                    }],
2912                }
2913            }
2914
2915            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2916            AggregateFunc::LastValue { window_frame, .. } => {
2917                let tuple = self
2918                    .expr
2919                    .clone()
2920                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2921
2922                // Get the overall return type
2923                let return_type_with_orig_row = self
2924                    .typ(input_type)
2925                    .scalar_type
2926                    .unwrap_list_element_type()
2927                    .clone();
2928                let last_value_return_type =
2929                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2930
2931                // Extract the original row
2932                let original_row = tuple
2933                    .clone()
2934                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2935
2936                // Extract the input value
2937                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2938
2939                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2940                    window_frame,
2941                    arg,
2942                    last_value_return_type,
2943                );
2944
2945                MirScalarExpr::CallVariadic {
2946                    func: VariadicFunc::ListCreate {
2947                        elem_type: return_type_with_orig_row,
2948                    },
2949                    exprs: vec![MirScalarExpr::CallVariadic {
2950                        func: VariadicFunc::RecordCreate {
2951                            field_names: vec![column_name, ColumnName::from("?record?")],
2952                        },
2953                        exprs: vec![result_expr, original_row],
2954                    }],
2955                }
2956            }
2957
2958            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2959            // See an example MIR in `window_func_applied_to`.
2960            AggregateFunc::WindowAggregate {
2961                wrapped_aggregate,
2962                window_frame,
2963                order_by: _,
2964            } => {
2965                // TODO: deduplicate code between the various window function cases.
2966
2967                let tuple = self
2968                    .expr
2969                    .clone()
2970                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2971
2972                // Get the overall return type
2973                let return_type = self
2974                    .typ(input_type)
2975                    .scalar_type
2976                    .unwrap_list_element_type()
2977                    .clone();
2978                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2979
2980                // Extract the original row
2981                let original_row = tuple
2982                    .clone()
2983                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2984
2985                // Extract the input value
2986                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2987
2988                let (result, column_name) = Self::on_unique_window_agg(
2989                    window_frame,
2990                    arg_expr,
2991                    input_type,
2992                    window_agg_return_type,
2993                    wrapped_aggregate,
2994                );
2995
2996                MirScalarExpr::CallVariadic {
2997                    func: VariadicFunc::ListCreate {
2998                        elem_type: return_type,
2999                    },
3000                    exprs: vec![MirScalarExpr::CallVariadic {
3001                        func: VariadicFunc::RecordCreate {
3002                            field_names: vec![column_name, ColumnName::from("?record?")],
3003                        },
3004                        exprs: vec![result, original_row],
3005                    }],
3006                }
3007            }
3008
3009            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
3010            AggregateFunc::FusedWindowAggregate {
3011                wrapped_aggregates,
3012                order_by: _,
3013                window_frame,
3014            } => {
3015                // Throw away OrderByExprs
3016                let tuple = self
3017                    .expr
3018                    .clone()
3019                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3020
3021                // Extract the original row
3022                let original_row = tuple
3023                    .clone()
3024                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3025
3026                // Extract the args of the fused call
3027                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3028
3029                let return_type_with_orig_row = self
3030                    .typ(input_type)
3031                    .scalar_type
3032                    .unwrap_list_element_type()
3033                    .clone();
3034
3035                let all_func_return_types =
3036                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3037                let mut func_result_exprs = Vec::new();
3038                let mut col_names = Vec::new();
3039                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
3040                    let arg = all_args
3041                        .clone()
3042                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3043                    let return_type =
3044                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3045                    let (result, column_name) = Self::on_unique_window_agg(
3046                        window_frame,
3047                        arg,
3048                        input_type,
3049                        return_type,
3050                        wrapped_aggr,
3051                    );
3052                    func_result_exprs.push(result);
3053                    col_names.push(column_name);
3054                }
3055
3056                MirScalarExpr::CallVariadic {
3057                    func: VariadicFunc::ListCreate {
3058                        elem_type: return_type_with_orig_row,
3059                    },
3060                    exprs: vec![MirScalarExpr::CallVariadic {
3061                        func: VariadicFunc::RecordCreate {
3062                            field_names: vec![
3063                                ColumnName::from("?fused_window_aggr?"),
3064                                ColumnName::from("?record?"),
3065                            ],
3066                        },
3067                        exprs: vec![
3068                            MirScalarExpr::CallVariadic {
3069                                func: VariadicFunc::RecordCreate {
3070                                    field_names: col_names,
3071                                },
3072                                exprs: func_result_exprs,
3073                            },
3074                            original_row,
3075                        ],
3076                    }],
3077                }
3078            }
3079
3080            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3081            AggregateFunc::FusedValueWindowFunc {
3082                funcs,
3083                order_by: outer_order_by,
3084            } => {
3085                // Throw away OrderByExprs
3086                let tuple = self
3087                    .expr
3088                    .clone()
3089                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3090
3091                // Extract the original row
3092                let original_row = tuple
3093                    .clone()
3094                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3095
3096                // Extract the encoded args of the fused call
3097                let all_encoded_args =
3098                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3099
3100                let return_type_with_orig_row = self
3101                    .typ(input_type)
3102                    .scalar_type
3103                    .unwrap_list_element_type()
3104                    .clone();
3105
3106                let all_func_return_types =
3107                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3108                let mut func_result_exprs = Vec::new();
3109                let mut col_names = Vec::new();
3110                for (idx, func) in funcs.iter().enumerate() {
3111                    let args_for_func = all_encoded_args
3112                        .clone()
3113                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3114                    let return_type_for_func =
3115                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3116                    let (result, column_name) = match func {
3117                        AggregateFunc::LagLead {
3118                            lag_lead,
3119                            order_by,
3120                            ignore_nulls: _,
3121                        } => {
3122                            assert_eq!(order_by, outer_order_by);
3123                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3124                        }
3125                        AggregateFunc::FirstValue {
3126                            window_frame,
3127                            order_by,
3128                        } => {
3129                            assert_eq!(order_by, outer_order_by);
3130                            Self::on_unique_first_value_last_value(
3131                                window_frame,
3132                                args_for_func,
3133                                return_type_for_func,
3134                            )
3135                        }
3136                        AggregateFunc::LastValue {
3137                            window_frame,
3138                            order_by,
3139                        } => {
3140                            assert_eq!(order_by, outer_order_by);
3141                            Self::on_unique_first_value_last_value(
3142                                window_frame,
3143                                args_for_func,
3144                                return_type_for_func,
3145                            )
3146                        }
3147                        _ => panic!("unknown function in FusedValueWindowFunc"),
3148                    };
3149                    func_result_exprs.push(result);
3150                    col_names.push(column_name);
3151                }
3152
3153                MirScalarExpr::CallVariadic {
3154                    func: VariadicFunc::ListCreate {
3155                        elem_type: return_type_with_orig_row,
3156                    },
3157                    exprs: vec![MirScalarExpr::CallVariadic {
3158                        func: VariadicFunc::RecordCreate {
3159                            field_names: vec![
3160                                ColumnName::from("?fused_value_window_func?"),
3161                                ColumnName::from("?record?"),
3162                            ],
3163                        },
3164                        exprs: vec![
3165                            MirScalarExpr::CallVariadic {
3166                                func: VariadicFunc::RecordCreate {
3167                                    field_names: col_names,
3168                                },
3169                                exprs: func_result_exprs,
3170                            },
3171                            original_row,
3172                        ],
3173                    }],
3174                }
3175            }
3176
3177            // All other variants should return the argument to the aggregation.
3178            AggregateFunc::MaxNumeric
3179            | AggregateFunc::MaxInt16
3180            | AggregateFunc::MaxInt32
3181            | AggregateFunc::MaxInt64
3182            | AggregateFunc::MaxUInt16
3183            | AggregateFunc::MaxUInt32
3184            | AggregateFunc::MaxUInt64
3185            | AggregateFunc::MaxMzTimestamp
3186            | AggregateFunc::MaxFloat32
3187            | AggregateFunc::MaxFloat64
3188            | AggregateFunc::MaxBool
3189            | AggregateFunc::MaxString
3190            | AggregateFunc::MaxDate
3191            | AggregateFunc::MaxTimestamp
3192            | AggregateFunc::MaxTimestampTz
3193            | AggregateFunc::MaxInterval
3194            | AggregateFunc::MaxTime
3195            | AggregateFunc::MinNumeric
3196            | AggregateFunc::MinInt16
3197            | AggregateFunc::MinInt32
3198            | AggregateFunc::MinInt64
3199            | AggregateFunc::MinUInt16
3200            | AggregateFunc::MinUInt32
3201            | AggregateFunc::MinUInt64
3202            | AggregateFunc::MinMzTimestamp
3203            | AggregateFunc::MinFloat32
3204            | AggregateFunc::MinFloat64
3205            | AggregateFunc::MinBool
3206            | AggregateFunc::MinString
3207            | AggregateFunc::MinDate
3208            | AggregateFunc::MinTimestamp
3209            | AggregateFunc::MinTimestampTz
3210            | AggregateFunc::MinInterval
3211            | AggregateFunc::MinTime
3212            | AggregateFunc::SumFloat32
3213            | AggregateFunc::SumFloat64
3214            | AggregateFunc::SumNumeric
3215            | AggregateFunc::Any
3216            | AggregateFunc::All
3217            | AggregateFunc::Dummy => self.expr.clone(),
3218        }
3219    }
3220
3221    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3222    fn on_unique_ranking_window_funcs(
3223        &self,
3224        input_type: &[ColumnType],
3225        col_name: &str,
3226    ) -> MirScalarExpr {
3227        let list = self
3228            .expr
3229            .clone()
3230            // extract the list within the record
3231            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3232
3233        // extract the expression within the list
3234        let record = MirScalarExpr::CallVariadic {
3235            func: VariadicFunc::ListIndex,
3236            exprs: vec![
3237                list,
3238                MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
3239            ],
3240        };
3241
3242        MirScalarExpr::CallVariadic {
3243            func: VariadicFunc::ListCreate {
3244                elem_type: self
3245                    .typ(input_type)
3246                    .scalar_type
3247                    .unwrap_list_element_type()
3248                    .clone(),
3249            },
3250            exprs: vec![MirScalarExpr::CallVariadic {
3251                func: VariadicFunc::RecordCreate {
3252                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3253                },
3254                exprs: vec![
3255                    MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
3256                    record,
3257                ],
3258            }],
3259        }
3260    }
3261
3262    /// `on_unique` for `lag` and `lead`
3263    fn on_unique_lag_lead(
3264        lag_lead: &LagLeadType,
3265        encoded_args: MirScalarExpr,
3266        return_type: ScalarType,
3267    ) -> (MirScalarExpr, ColumnName) {
3268        let expr = encoded_args
3269            .clone()
3270            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3271        let offset = encoded_args
3272            .clone()
3273            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3274        let default_value =
3275            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3276
3277        // In this case, the window always has only one element, so if the offset is not null and
3278        // not zero, the default value should be returned instead.
3279        let value = offset
3280            .clone()
3281            .call_binary(
3282                MirScalarExpr::literal_ok(Datum::Int32(0), ScalarType::Int32),
3283                crate::BinaryFunc::Eq,
3284            )
3285            .if_then_else(expr, default_value);
3286        let result_expr = offset
3287            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3288            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3289
3290        let column_name = ColumnName::from(match lag_lead {
3291            LagLeadType::Lag => "?lag?",
3292            LagLeadType::Lead => "?lead?",
3293        });
3294
3295        (result_expr, column_name)
3296    }
3297
3298    /// `on_unique` for `first_value` and `last_value`
3299    fn on_unique_first_value_last_value(
3300        window_frame: &WindowFrame,
3301        arg: MirScalarExpr,
3302        return_type: ScalarType,
3303    ) -> (MirScalarExpr, ColumnName) {
3304        // If the window frame includes the current (single) row, return its value, null otherwise
3305        let result_expr = if window_frame.includes_current_row() {
3306            arg
3307        } else {
3308            MirScalarExpr::literal_null(return_type)
3309        };
3310        (result_expr, ColumnName::from("?first_value?"))
3311    }
3312
3313    /// `on_unique` for window aggregations
3314    fn on_unique_window_agg(
3315        window_frame: &WindowFrame,
3316        arg_expr: MirScalarExpr,
3317        input_type: &[ColumnType],
3318        return_type: ScalarType,
3319        wrapped_aggr: &AggregateFunc,
3320    ) -> (MirScalarExpr, ColumnName) {
3321        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3322        // that row. Otherwise, return the default value for the aggregate.
3323        let result_expr = if window_frame.includes_current_row() {
3324            AggregateExpr {
3325                func: wrapped_aggr.clone(),
3326                expr: arg_expr,
3327                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3328            }
3329            .on_unique(input_type)
3330        } else {
3331            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3332        };
3333        (result_expr, ColumnName::from("?window_agg?"))
3334    }
3335
3336    /// Returns whether the expression is COUNT(*) or not.  Note that
3337    /// when we define the count builtin in sql::func, we convert
3338    /// COUNT(*) to COUNT(true), making it indistinguishable from
3339    /// literal COUNT(true), but we prefer to consider this as the
3340    /// former.
3341    ///
3342    /// (HIR has the same `is_count_asterisk`.)
3343    pub fn is_count_asterisk(&self) -> bool {
3344        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3345    }
3346}
3347
/// Describe a join implementation in dataflow.
#[derive(
    Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument describes the starting collection as a triple of:
    ///
    /// 1. the index of the starting collection,
    /// 2. the keys to arrange it by, if it should be arranged, and
    /// 3. the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists the other relation indexes, each with the key for
    /// the arrangement we should use when joining it in.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    ///
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still
    /// want to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`).
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3397
3398impl Default for JoinImplementation {
3399    fn default() -> Self {
3400        JoinImplementation::Unimplemented
3401    }
3402}
3403
3404impl JoinImplementation {
3405    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3406    pub fn is_implemented(&self) -> bool {
3407        match self {
3408            Self::Unimplemented => false,
3409            _ => true,
3410        }
3411    }
3412
3413    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3414    pub fn name(&self) -> Option<&'static str> {
3415        match self {
3416            Self::Differential(..) => Some("differential"),
3417            Self::DeltaQuery(..) => Some("delta"),
3418            Self::IndexedFilter(..) => Some("indexed_filter"),
3419            Self::Unimplemented => None,
3420        }
3421    }
3422}
3423
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. [`JoinInputCharacteristics::new`] instantiates the
/// appropriate version based on the `enable_join_prioritize_arranged` feature flag.
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3444
3445impl JoinInputCharacteristics {
3446    /// Creates a new instance with the given characteristics.
3447    pub fn new(
3448        unique_key: bool,
3449        key_length: usize,
3450        arranged: bool,
3451        cardinality: Option<usize>,
3452        filters: FilterCharacteristics,
3453        input: usize,
3454        enable_join_prioritize_arranged: bool,
3455    ) -> Self {
3456        if enable_join_prioritize_arranged {
3457            Self::V2(JoinInputCharacteristicsV2::new(
3458                unique_key,
3459                key_length,
3460                arranged,
3461                cardinality,
3462                filters,
3463                input,
3464            ))
3465        } else {
3466            Self::V1(JoinInputCharacteristicsV1::new(
3467                unique_key,
3468                key_length,
3469                arranged,
3470                cardinality,
3471                filters,
3472                input,
3473            ))
3474        }
3475    }
3476
3477    /// Turns the instance into a String to be printed in EXPLAIN.
3478    pub fn explain(&self) -> String {
3479        match self {
3480            Self::V1(jic) => jic.explain(),
3481            Self::V2(jic) => jic.explain(),
3482        }
3483    }
3484
3485    /// Whether the join input described by `self` is arranged.
3486    pub fn arranged(&self) -> bool {
3487        match self {
3488            Self::V1(jic) => jic.arranged,
3489            Self::V2(jic) => jic.arranged,
3490        }
3491    }
3492
3493    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3494    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3495        match self {
3496            Self::V1(jic) => &mut jic.filters,
3497            Self::V2(jic) => &mut jic.filters,
3498        }
3499    }
3500}
3501
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// The derived `Ord` compares fields in declaration order, so the order of the fields
/// below determines their relative priority when comparing candidates.
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better, hence the `Reverse` wrapper).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering
    /// (hence the `Reverse` wrapper).
    pub input: std::cmp::Reverse<usize>,
}
3525
3526impl JoinInputCharacteristicsV2 {
3527    /// Creates a new instance with the given characteristics.
3528    pub fn new(
3529        unique_key: bool,
3530        key_length: usize,
3531        arranged: bool,
3532        cardinality: Option<usize>,
3533        filters: FilterCharacteristics,
3534        input: usize,
3535    ) -> Self {
3536        Self {
3537            unique_key,
3538            not_cross: key_length > 0,
3539            arranged,
3540            key_length,
3541            cardinality: cardinality.map(std::cmp::Reverse),
3542            filters,
3543            input: std::cmp::Reverse(input),
3544        }
3545    }
3546
3547    /// Turns the instance into a String to be printed in EXPLAIN.
3548    pub fn explain(&self) -> String {
3549        let mut e = "".to_owned();
3550        if self.unique_key {
3551            e.push_str("U");
3552        }
3553        // Don't need to print `not_cross`, because that is visible in the printed key.
3554        // if !self.not_cross {
3555        //     e.push_str("C");
3556        // }
3557        for _ in 0..self.key_length {
3558            e.push_str("K");
3559        }
3560        if self.arranged {
3561            e.push_str("A");
3562        }
3563        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3564            e.push_str(&format!("|{cardinality}|"));
3565        }
3566        e.push_str(&self.filters.explain());
3567        e
3568    }
3569}
3570
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// The derived `Ord` compares fields in declaration order, so the order of the fields
/// below determines their relative priority when comparing candidates.
#[derive(
    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better, hence the `Reverse` wrapper).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering
    /// (hence the `Reverse` wrapper).
    pub input: std::cmp::Reverse<usize>,
}
3589
3590impl JoinInputCharacteristicsV1 {
3591    /// Creates a new instance with the given characteristics.
3592    pub fn new(
3593        unique_key: bool,
3594        key_length: usize,
3595        arranged: bool,
3596        cardinality: Option<usize>,
3597        filters: FilterCharacteristics,
3598        input: usize,
3599    ) -> Self {
3600        Self {
3601            unique_key,
3602            key_length,
3603            arranged,
3604            cardinality: cardinality.map(std::cmp::Reverse),
3605            filters,
3606            input: std::cmp::Reverse(input),
3607        }
3608    }
3609
3610    /// Turns the instance into a String to be printed in EXPLAIN.
3611    pub fn explain(&self) -> String {
3612        let mut e = "".to_owned();
3613        if self.unique_key {
3614            e.push_str("U");
3615        }
3616        for _ in 0..self.key_length {
3617            e.push_str("K");
3618        }
3619        if self.arranged {
3620            e.push_str("A");
3621        }
3622        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3623            e.push_str(&format!("|{cardinality}|"));
3624        }
3625        e.push_str(&self.filters.explain());
3626        e
3627    }
3628}
3629
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The limit type `L` defaults to `NonNeg<i64>`.
#[derive(Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset). `None` means no limit.
    pub limit: Option<L>,
    /// Omit as many rows.
    pub offset: usize,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3648
3649impl RustType<ProtoRowSetFinishing> for RowSetFinishing {
3650    fn into_proto(&self) -> ProtoRowSetFinishing {
3651        ProtoRowSetFinishing {
3652            order_by: self.order_by.into_proto(),
3653            limit: self.limit.into_proto(),
3654            offset: self.offset.into_proto(),
3655            project: self.project.into_proto(),
3656        }
3657    }
3658
3659    fn from_proto(x: ProtoRowSetFinishing) -> Result<Self, TryFromProtoError> {
3660        Ok(RowSetFinishing {
3661            order_by: x.order_by.into_rust()?,
3662            limit: x.limit.into_rust()?,
3663            offset: x.offset.into_rust()?,
3664            project: x.project.into_rust()?,
3665        })
3666    }
3667}
3668
3669impl<L> RowSetFinishing<L> {
3670    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3671    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3672        RowSetFinishing {
3673            order_by: Vec::new(),
3674            limit: None,
3675            offset: 0,
3676            project: (0..arity).collect(),
3677        }
3678    }
3679    /// True if the finishing does nothing to any result set.
3680    pub fn is_trivial(&self, arity: usize) -> bool {
3681        self.limit.is_none()
3682            && self.order_by.is_empty()
3683            && self.offset == 0
3684            && self.project.iter().copied().eq(0..arity)
3685    }
3686}
3687
3688impl RowSetFinishing {
3689    /// Applies finishing actions to a [`RowCollection`], and reports the total
3690    /// time it took to run.
3691    ///
3692    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
3693    /// well as the size of the response in bytes.
3694    pub fn finish(
3695        &self,
3696        rows: RowCollection,
3697        max_result_size: u64,
3698        max_returned_query_size: Option<u64>,
3699        duration_histogram: &Histogram,
3700    ) -> Result<(SortedRowCollectionIter, usize), String> {
3701        let now = Instant::now();
3702        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3703        let duration = now.elapsed();
3704        duration_histogram.observe(duration.as_secs_f64());
3705
3706        result
3707    }
3708
3709    /// Implementation for [`RowSetFinishing::finish`].
3710    fn finish_inner(
3711        &self,
3712        rows: RowCollection,
3713        max_result_size: u64,
3714        max_returned_query_size: Option<u64>,
3715    ) -> Result<(SortedRowCollectionIter, usize), String> {
3716        // How much additional memory is required to make a sorted view.
3717        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3718        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3719
3720        // Bail if creating the sorted view would require us to use too much memory.
3721        if required_memory > usize::cast_from(max_result_size) {
3722            let max_bytes = ByteSize::b(max_result_size);
3723            return Err(format!("result exceeds max size of {max_bytes}",));
3724        }
3725
3726        let sorted_view = rows.sorted_view(&self.order_by);
3727        let mut iter = sorted_view
3728            .into_row_iter()
3729            .apply_offset(self.offset)
3730            .with_projection(self.project.clone());
3731
3732        if let Some(limit) = self.limit {
3733            let limit = u64::from(limit);
3734            let limit = usize::cast_from(limit);
3735            iter = iter.with_limit(limit);
3736        };
3737
3738        // TODO(parkmycar): Re-think how we can calculate the total response size without
3739        // having to iterate through the entire collection of Rows, while still
3740        // respecting the LIMIT, OFFSET, and projections.
3741        //
3742        // Note: It feels a bit bad always calculating the response size, but we almost
3743        // always need it to either check the `max_returned_query_size`, or for reporting
3744        // in the query history.
3745        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3746
3747        // Bail if we would end up returning more data to the client than they can support.
3748        if let Some(max) = max_returned_query_size {
3749            if response_size > usize::cast_from(max) {
3750                let max_bytes = ByteSize::b(max);
3751                return Err(format!("result exceeds max size of {max_bytes}"));
3752            }
3753        }
3754
3755        Ok((iter, response_size))
3756    }
3757}
3758
3759/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3760/// ordering, call `tiebreaker`.
3761pub fn compare_columns<F>(
3762    order: &[ColumnOrder],
3763    left: &[Datum],
3764    right: &[Datum],
3765    tiebreaker: F,
3766) -> Ordering
3767where
3768    F: Fn() -> Ordering,
3769{
3770    for order in order {
3771        let cmp = match (&left[order.column], &right[order.column]) {
3772            (Datum::Null, Datum::Null) => Ordering::Equal,
3773            (Datum::Null, _) => {
3774                if order.nulls_last {
3775                    Ordering::Greater
3776                } else {
3777                    Ordering::Less
3778                }
3779            }
3780            (_, Datum::Null) => {
3781                if order.nulls_last {
3782                    Ordering::Less
3783                } else {
3784                    Ordering::Greater
3785                }
3786            }
3787            (lval, rval) => {
3788                if order.desc {
3789                    rval.cmp(lval)
3790                } else {
3791                    lval.cmp(rval)
3792                }
3793            }
3794        };
3795        if cmp != Ordering::Equal {
3796            return cmp;
3797        }
3798    }
3799    tiebreaker()
3800}
3801
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
///
/// A frame is displayed as `<units> between <start_bound> and <end_bound>`.
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3818
3819impl Display for WindowFrame {
3820    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3821        write!(
3822            f,
3823            "{} between {} and {}",
3824            self.units, self.start_bound, self.end_bound
3825        )
3826    }
3827}
3828
3829impl WindowFrame {
3830    /// Return the default window frame used when one is not explicitly defined
3831    pub fn default() -> Self {
3832        WindowFrame {
3833            units: WindowFrameUnits::Range,
3834            start_bound: WindowFrameBound::UnboundedPreceding,
3835            end_bound: WindowFrameBound::CurrentRow,
3836        }
3837    }
3838
3839    fn includes_current_row(&self) -> bool {
3840        use WindowFrameBound::*;
3841        match self.start_bound {
3842            UnboundedPreceding => match self.end_bound {
3843                UnboundedPreceding => false,
3844                OffsetPreceding(0) => true,
3845                OffsetPreceding(_) => false,
3846                CurrentRow => true,
3847                OffsetFollowing(_) => true,
3848                UnboundedFollowing => true,
3849            },
3850            OffsetPreceding(0) => match self.end_bound {
3851                UnboundedPreceding => unreachable!(),
3852                OffsetPreceding(0) => true,
3853                // Any nonzero offsets here will create an empty window
3854                OffsetPreceding(_) => false,
3855                CurrentRow => true,
3856                OffsetFollowing(_) => true,
3857                UnboundedFollowing => true,
3858            },
3859            OffsetPreceding(_) => match self.end_bound {
3860                UnboundedPreceding => unreachable!(),
3861                // Window ends at the current row
3862                OffsetPreceding(0) => true,
3863                OffsetPreceding(_) => false,
3864                CurrentRow => true,
3865                OffsetFollowing(_) => true,
3866                UnboundedFollowing => true,
3867            },
3868            CurrentRow => true,
3869            OffsetFollowing(0) => match self.end_bound {
3870                UnboundedPreceding => unreachable!(),
3871                OffsetPreceding(_) => unreachable!(),
3872                CurrentRow => unreachable!(),
3873                OffsetFollowing(_) => true,
3874                UnboundedFollowing => true,
3875            },
3876            OffsetFollowing(_) => match self.end_bound {
3877                UnboundedPreceding => unreachable!(),
3878                OffsetPreceding(_) => unreachable!(),
3879                CurrentRow => unreachable!(),
3880                OffsetFollowing(_) => false,
3881                UnboundedFollowing => false,
3882            },
3883            UnboundedFollowing => false,
3884        }
3885    }
3886}
3887
3888impl RustType<ProtoWindowFrame> for WindowFrame {
3889    fn into_proto(&self) -> ProtoWindowFrame {
3890        ProtoWindowFrame {
3891            units: Some(self.units.into_proto()),
3892            start_bound: Some(self.start_bound.into_proto()),
3893            end_bound: Some(self.end_bound.into_proto()),
3894        }
3895    }
3896
3897    fn from_proto(proto: ProtoWindowFrame) -> Result<Self, TryFromProtoError> {
3898        Ok(WindowFrame {
3899            units: proto.units.into_rust_if_some("ProtoWindowFrame::units")?,
3900            start_bound: proto
3901                .start_bound
3902                .into_rust_if_some("ProtoWindowFrame::start_bound")?,
3903            end_bound: proto
3904                .end_bound
3905                .into_rust_if_some("ProtoWindowFrame::end_bound")?,
3906        })
3907    }
3908}
3909
/// Describe how frame bounds are interpreted
///
/// Displayed in lowercase as `rows`, `range`, or `groups`.
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3924
3925impl Display for WindowFrameUnits {
3926    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3927        match self {
3928            WindowFrameUnits::Rows => write!(f, "rows"),
3929            WindowFrameUnits::Range => write!(f, "range"),
3930            WindowFrameUnits::Groups => write!(f, "groups"),
3931        }
3932    }
3933}
3934
3935impl RustType<proto_window_frame::ProtoWindowFrameUnits> for WindowFrameUnits {
3936    fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameUnits {
3937        use proto_window_frame::proto_window_frame_units::Kind::*;
3938        proto_window_frame::ProtoWindowFrameUnits {
3939            kind: Some(match self {
3940                WindowFrameUnits::Rows => Rows(()),
3941                WindowFrameUnits::Range => Range(()),
3942                WindowFrameUnits::Groups => Groups(()),
3943            }),
3944        }
3945    }
3946
3947    fn from_proto(
3948        proto: proto_window_frame::ProtoWindowFrameUnits,
3949    ) -> Result<Self, TryFromProtoError> {
3950        use proto_window_frame::proto_window_frame_units::Kind::*;
3951        Ok(match proto.kind {
3952            Some(Rows(())) => WindowFrameUnits::Rows,
3953            Some(Range(())) => WindowFrameUnits::Range,
3954            Some(Groups(())) => WindowFrameUnits::Groups,
3955            None => {
3956                return Err(TryFromProtoError::missing_field(
3957                    "ProtoWindowFrameUnits::kind",
3958                ));
3959            }
3960        })
3961    }
3962}
3963
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there.
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order, and
/// compare two offsets of the same variant numerically. Note that for two
/// `OffsetPreceding` values this numeric order is the reverse of their position
/// in the partition (`1 PRECEDING` sorts before `5 PRECEDING`).
#[derive(
    Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, MzReflect, PartialOrd, Ord,
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
3983
3984impl Display for WindowFrameBound {
3985    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3986        match self {
3987            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
3988            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
3989            WindowFrameBound::CurrentRow => write!(f, "current row"),
3990            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
3991            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
3992        }
3993    }
3994}
3995
3996impl RustType<proto_window_frame::ProtoWindowFrameBound> for WindowFrameBound {
3997    fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameBound {
3998        use proto_window_frame::proto_window_frame_bound::Kind::*;
3999        proto_window_frame::ProtoWindowFrameBound {
4000            kind: Some(match self {
4001                WindowFrameBound::UnboundedPreceding => UnboundedPreceding(()),
4002                WindowFrameBound::OffsetPreceding(offset) => OffsetPreceding(*offset),
4003                WindowFrameBound::CurrentRow => CurrentRow(()),
4004                WindowFrameBound::OffsetFollowing(offset) => OffsetFollowing(*offset),
4005                WindowFrameBound::UnboundedFollowing => UnboundedFollowing(()),
4006            }),
4007        }
4008    }
4009
4010    fn from_proto(x: proto_window_frame::ProtoWindowFrameBound) -> Result<Self, TryFromProtoError> {
4011        use proto_window_frame::proto_window_frame_bound::Kind::*;
4012        Ok(match x.kind {
4013            Some(UnboundedPreceding(())) => WindowFrameBound::UnboundedPreceding,
4014            Some(OffsetPreceding(offset)) => WindowFrameBound::OffsetPreceding(offset),
4015            Some(CurrentRow(())) => WindowFrameBound::CurrentRow,
4016            Some(OffsetFollowing(offset)) => WindowFrameBound::OffsetFollowing(offset),
4017            Some(UnboundedFollowing(())) => WindowFrameBound::UnboundedFollowing,
4018            None => {
4019                return Err(TryFromProtoError::missing_field(
4020                    "ProtoWindowFrameBound::kind",
4021                ));
4022            }
4023        })
4024    }
4025}
4026
/// Maximum iterations for a LetRec.
///
/// Displayed as `[recursion_limit=<max_iters>]`, with `, return_at_limit` appended
/// when `return_at_limit` differs from [`LetRecLimit::RETURN_AT_LIMIT_DEFAULT`].
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Arbitrary,
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return, rather than throw an error, when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4038
4039impl LetRecLimit {
4040    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4041    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4042        limits
4043            .iter()
4044            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4045            .min()
4046    }
4047
4048    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4049    /// WMR without ERROR AT or RETURN AT.
4050    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4051}
4052
4053impl Display for LetRecLimit {
4054    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4055        write!(f, "[recursion_limit={}", self.max_iters)?;
4056        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4057            write!(f, ", return_at_limit")?;
4058        }
4059        write!(f, "]")
4060    }
4061}
4062
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
///
/// Local Gets (CTEs) always carry [`AccessStrategy::UnknownOrLocal`].
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, Arbitrary)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4079
// Unit tests: protobuf round-trip properties for the types in this file that have
// proto representations, plus a check of the text rendering of a `RowSetFinishing`.
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use mz_repr::explain::text::text_string_at;
    use proptest::prelude::*;

    use crate::explain::HumanizedExplain;

    use super::*;

    // Each property below passes an arbitrary value through `protobuf_roundtrip` with
    // the named proto type, and expects the call to succeed and return an equal value.

    proptest! {
        #[mz_ore::test]
        fn column_order_protobuf_roundtrip(expect in any::<ColumnOrder>()) {
            let actual = protobuf_roundtrip::<_, ProtoColumnOrder>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
        fn aggregate_expr_protobuf_roundtrip(expect in any::<AggregateExpr>()) {
            let actual = protobuf_roundtrip::<_, ProtoAggregateExpr>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    proptest! {
        #[mz_ore::test]
        fn window_frame_units_protobuf_roundtrip(expect in any::<WindowFrameUnits>()) {
            let actual = protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameUnits>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    proptest! {
        #[mz_ore::test]
        fn window_frame_bound_protobuf_roundtrip(expect in any::<WindowFrameBound>()) {
            let actual = protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameBound>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    proptest! {
        #[mz_ore::test]
        fn window_frame_protobuf_roundtrip(expect in any::<WindowFrame>()) {
            let actual = protobuf_roundtrip::<_, ProtoWindowFrame>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Renders a finishing with an ordering, a limit, a default offset, and a
    // non-identity projection, and compares the text against the expected line.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Expected output: a single `Finish` line. The default offset does not appear,
        // and the consecutive projected columns 3, 4, 5 are shown as the range `#3..=#5`.
        let exp = {
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }
}
4169
/// Iterators over AST structures that call out the nodes at which two ASTs differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the places in the ASTs where they differ, and incidentally to provide a
/// stack-free way to compare two ASTs.
4176mod structured_diff {
4177
4178    use super::MirRelationExpr;
4179
    /// An iterator over structured differences between two `MirRelationExpr` instances.
    ///
    /// Each item is a pair of corresponding sub-expressions, one from each input,
    /// whose node data does not match. The iterator descends no further into a
    /// pair once it has been reported.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        /// Used as a stack: the most recently added pair is compared next.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }
4185
4186    impl<'a> MreDiff<'a> {
4187        /// Create a new `MirRelationExpr` structured difference.
4188        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
4189            MreDiff {
4190                todo: vec![(expr1, expr2)],
4191            }
4192        }
4193    }
4194
4195    impl<'a> Iterator for MreDiff<'a> {
4196        // Pairs of expressions that do not match.
4197        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);
4198
4199        fn next(&mut self) -> Option<Self::Item> {
4200            while let Some((expr1, expr2)) = self.todo.pop() {
4201                match (expr1, expr2) {
4202                    (
4203                        MirRelationExpr::Constant {
4204                            rows: rows1,
4205                            typ: typ1,
4206                        },
4207                        MirRelationExpr::Constant {
4208                            rows: rows2,
4209                            typ: typ2,
4210                        },
4211                    ) => {
4212                        if rows1 != rows2 || typ1 != typ2 {
4213                            return Some((expr1, expr2));
4214                        }
4215                    }
4216                    (
4217                        MirRelationExpr::Get {
4218                            id: id1,
4219                            typ: typ1,
4220                            access_strategy: as1,
4221                        },
4222                        MirRelationExpr::Get {
4223                            id: id2,
4224                            typ: typ2,
4225                            access_strategy: as2,
4226                        },
4227                    ) => {
4228                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
4229                            return Some((expr1, expr2));
4230                        }
4231                    }
4232                    (
4233                        MirRelationExpr::Let {
4234                            id: id1,
4235                            body: body1,
4236                            value: value1,
4237                        },
4238                        MirRelationExpr::Let {
4239                            id: id2,
4240                            body: body2,
4241                            value: value2,
4242                        },
4243                    ) => {
4244                        if id1 != id2 {
4245                            return Some((expr1, expr2));
4246                        } else {
4247                            self.todo.push((body1, body2));
4248                            self.todo.push((value1, value2));
4249                        }
4250                    }
4251                    (
4252                        MirRelationExpr::LetRec {
4253                            ids: ids1,
4254                            body: body1,
4255                            values: values1,
4256                            limits: limits1,
4257                        },
4258                        MirRelationExpr::LetRec {
4259                            ids: ids2,
4260                            body: body2,
4261                            values: values2,
4262                            limits: limits2,
4263                        },
4264                    ) => {
4265                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
4266                            return Some((expr1, expr2));
4267                        } else {
4268                            self.todo.push((body1, body2));
4269                            self.todo.extend(values1.iter().zip(values2.iter()));
4270                        }
4271                    }
4272                    (
4273                        MirRelationExpr::Project {
4274                            outputs: outputs1,
4275                            input: input1,
4276                        },
4277                        MirRelationExpr::Project {
4278                            outputs: outputs2,
4279                            input: input2,
4280                        },
4281                    ) => {
4282                        if outputs1 != outputs2 {
4283                            return Some((expr1, expr2));
4284                        } else {
4285                            self.todo.push((input1, input2));
4286                        }
4287                    }
4288                    (
4289                        MirRelationExpr::Map {
4290                            scalars: scalars1,
4291                            input: input1,
4292                        },
4293                        MirRelationExpr::Map {
4294                            scalars: scalars2,
4295                            input: input2,
4296                        },
4297                    ) => {
4298                        if scalars1 != scalars2 {
4299                            return Some((expr1, expr2));
4300                        } else {
4301                            self.todo.push((input1, input2));
4302                        }
4303                    }
4304                    (
4305                        MirRelationExpr::Filter {
4306                            predicates: predicates1,
4307                            input: input1,
4308                        },
4309                        MirRelationExpr::Filter {
4310                            predicates: predicates2,
4311                            input: input2,
4312                        },
4313                    ) => {
4314                        if predicates1 != predicates2 {
4315                            return Some((expr1, expr2));
4316                        } else {
4317                            self.todo.push((input1, input2));
4318                        }
4319                    }
4320                    (
4321                        MirRelationExpr::FlatMap {
4322                            input: input1,
4323                            func: func1,
4324                            exprs: exprs1,
4325                        },
4326                        MirRelationExpr::FlatMap {
4327                            input: input2,
4328                            func: func2,
4329                            exprs: exprs2,
4330                        },
4331                    ) => {
4332                        if func1 != func2 || exprs1 != exprs2 {
4333                            return Some((expr1, expr2));
4334                        } else {
4335                            self.todo.push((input1, input2));
4336                        }
4337                    }
4338                    (
4339                        MirRelationExpr::Join {
4340                            inputs: inputs1,
4341                            equivalences: eq1,
4342                            implementation: impl1,
4343                        },
4344                        MirRelationExpr::Join {
4345                            inputs: inputs2,
4346                            equivalences: eq2,
4347                            implementation: impl2,
4348                        },
4349                    ) => {
4350                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
4351                            return Some((expr1, expr2));
4352                        } else {
4353                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
4354                        }
4355                    }
4356                    (
4357                        MirRelationExpr::Reduce {
4358                            aggregates: aggregates1,
4359                            input: inputs1,
4360                            group_key: gk1,
4361                            monotonic: m1,
4362                            expected_group_size: egs1,
4363                        },
4364                        MirRelationExpr::Reduce {
4365                            aggregates: aggregates2,
4366                            input: inputs2,
4367                            group_key: gk2,
4368                            monotonic: m2,
4369                            expected_group_size: egs2,
4370                        },
4371                    ) => {
4372                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
4373                            return Some((expr1, expr2));
4374                        } else {
4375                            self.todo.push((inputs1, inputs2));
4376                        }
4377                    }
4378                    (
4379                        MirRelationExpr::TopK {
4380                            group_key: gk1,
4381                            order_key: order1,
4382                            input: input1,
4383                            limit: l1,
4384                            offset: o1,
4385                            monotonic: m1,
4386                            expected_group_size: egs1,
4387                        },
4388                        MirRelationExpr::TopK {
4389                            group_key: gk2,
4390                            order_key: order2,
4391                            input: input2,
4392                            limit: l2,
4393                            offset: o2,
4394                            monotonic: m2,
4395                            expected_group_size: egs2,
4396                        },
4397                    ) => {
4398                        if order1 != order2
4399                            || gk1 != gk2
4400                            || l1 != l2
4401                            || o1 != o2
4402                            || m1 != m2
4403                            || egs1 != egs2
4404                        {
4405                            return Some((expr1, expr2));
4406                        } else {
4407                            self.todo.push((input1, input2));
4408                        }
4409                    }
4410                    (
4411                        MirRelationExpr::Negate { input: input1 },
4412                        MirRelationExpr::Negate { input: input2 },
4413                    ) => {
4414                        self.todo.push((input1, input2));
4415                    }
4416                    (
4417                        MirRelationExpr::Threshold { input: input1 },
4418                        MirRelationExpr::Threshold { input: input2 },
4419                    ) => {
4420                        self.todo.push((input1, input2));
4421                    }
4422                    (
4423                        MirRelationExpr::Union {
4424                            base: base1,
4425                            inputs: inputs1,
4426                        },
4427                        MirRelationExpr::Union {
4428                            base: base2,
4429                            inputs: inputs2,
4430                        },
4431                    ) => {
4432                        if inputs1.len() != inputs2.len() {
4433                            return Some((expr1, expr2));
4434                        } else {
4435                            self.todo.push((base1, base2));
4436                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
4437                        }
4438                    }
4439                    (
4440                        MirRelationExpr::ArrangeBy {
4441                            keys: keys1,
4442                            input: input1,
4443                        },
4444                        MirRelationExpr::ArrangeBy {
4445                            keys: keys2,
4446                            input: input2,
4447                        },
4448                    ) => {
4449                        if keys1 != keys2 {
4450                            return Some((expr1, expr2));
4451                        } else {
4452                            self.todo.push((input1, input2));
4453                        }
4454                    }
4455                    _ => {
4456                        return Some((expr1, expr2));
4457                    }
4458                }
4459            }
4460            None
4461        }
4462    }
4463}