mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::{fmt, mem};
17
18use itertools::Itertools;
19use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
20use mz_expr::visit::{Visit, VisitChildren};
21use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
22// these happen to be unchanged at the moment, but there might be additions later
23use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
24pub use mz_expr::{
25    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
26};
27use mz_ore::collections::CollectionExt;
28use mz_ore::stack;
29use mz_ore::stack::RecursionLimitError;
30use mz_ore::str::separated;
31use mz_repr::adt::array::ArrayDimension;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::*;
34use serde::{Deserialize, Serialize};
35
36use crate::plan::Params;
37use crate::plan::error::PlanError;
38use crate::plan::query::ExprContext;
39use crate::plan::typeconv::{self, CastContext};
40
41use super::plan_utils::GroupSizeHints;
42
/// Unit type that stands for the HIR representation as a whole.
///
/// It holds no data. Its trait implementations (`IR`, `AlgExcept`) associate it with
/// [`HirRelationExpr`] and [`HirScalarExpr`].
#[allow(missing_debug_implementations)]
pub struct Hir;
45
/// Declares which expression types make up HIR.
impl IR for Hir {
    /// Relation-level expressions of HIR.
    type Relation = HirRelationExpr;
    /// Scalar-level expressions of HIR.
    type Scalar = HirScalarExpr;
}
50
51impl AlgExcept for Hir {
52    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
53        if *all {
54            let rhs = rhs.negate();
55            HirRelationExpr::union(lhs, rhs).threshold()
56        } else {
57            let lhs = lhs.distinct();
58            let rhs = rhs.distinct().negate();
59            HirRelationExpr::union(lhs, rhs).threshold()
60        }
61    }
62
63    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
64        let mut result = None;
65
66        use HirRelationExpr::*;
67        if let Threshold { input } = expr {
68            if let Union { base: lhs, inputs } = input.as_ref() {
69                if let [rhs] = &inputs[..] {
70                    if let Negate { input: rhs } = rhs {
71                        match (lhs.as_ref(), rhs.as_ref()) {
72                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
73                                let all = false;
74                                let lhs = lhs.as_ref();
75                                let rhs = rhs.as_ref();
76                                result = Some(Except { all, lhs, rhs })
77                            }
78                            (lhs, rhs) => {
79                                let all = true;
80                                result = Some(Except { all, lhs, rhs })
81                            }
82                        }
83                    }
84                }
85            }
86        }
87
88        result
89    }
90}
91
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
///
/// Variants that carry no note of their own are meant to be read as the MIR variant of
/// the same name.
pub enum HirRelationExpr {
    /// A literal collection made of `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: RelationType,
    },
    /// A reference to the collection identified by `id`, with relation type `typ`.
    Get {
        id: mz_expr::Id,
        typ: RelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, RelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// See `Project` on [`mz_expr::MirRelationExpr`].
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// See `Map` on [`mz_expr::MirRelationExpr`]; the scalars are HIR scalar expressions.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call of the table function `func` with the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// See `Filter` on [`mz_expr::MirRelationExpr`]; the predicates are HIR scalar
    /// expressions.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    /// See `Distinct` as used by [`AlgExcept::except`] on `Hir`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        limit: Option<HirScalarExpr>,
        /// Number of records to skip
        offset: usize,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// See `Negate` on [`mz_expr::MirRelationExpr`].
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
185
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef),
    /// A query parameter, identified by its number.
    Parameter(usize),
    /// A literal value, stored as a `Row` together with its column type.
    Literal(Row, ColumnType),
    /// A call of an unmaterializable function; it takes no argument expressions.
    CallUnmaterializable(UnmaterializableFunc),
    /// A call of the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
    },
    /// A call of the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
    },
    /// A call of the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// A conditional with condition `cond` and the two branches `then` and `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, the sql spec says we should throw an error but we can't
    ///   (see <https://github.com/MaterializeInc/database-issues/issues/154>)
    ///   so instead we return all the rows.
    ///   If there are multiple `Select` expressions in a single SQL query, the result is that we take the product of all of them.
    ///   This is counter to the spec, but is consistent with eg postgres' treatment of multiple set-returning-functions
    ///   (see <https://tapoueh.org/blog/2017/10/set-returning-functions-and-postgresql-10/>).
    Select(Box<HirRelationExpr>),
    /// A window function invocation; see [`WindowExpr`].
    Windowing(WindowExpr),
}
228
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, together with its own parameters.
    pub func: WindowExprType,
    /// The partitioning expressions; empty when the invocation has no partitioning.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
247
248impl WindowExpr {
249    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
250    where
251        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
252    {
253        #[allow(deprecated)]
254        self.func.visit_expressions(f)?;
255        for expr in self.partition_by.iter() {
256            f(expr)?;
257        }
258        for expr in self.order_by.iter() {
259            f(expr)?;
260        }
261        Ok(())
262    }
263
264    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
265    where
266        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
267    {
268        #[allow(deprecated)]
269        self.func.visit_expressions_mut(f)?;
270        for expr in self.partition_by.iter_mut() {
271            f(expr)?;
272        }
273        for expr in self.order_by.iter_mut() {
274            f(expr)?;
275        }
276        Ok(())
277    }
278}
279
280impl VisitChildren<HirScalarExpr> for WindowExpr {
281    fn visit_children<F>(&self, mut f: F)
282    where
283        F: FnMut(&HirScalarExpr),
284    {
285        self.func.visit_children(&mut f);
286        for expr in self.partition_by.iter() {
287            f(expr);
288        }
289        for expr in self.order_by.iter() {
290            f(expr);
291        }
292    }
293
294    fn visit_mut_children<F>(&mut self, mut f: F)
295    where
296        F: FnMut(&mut HirScalarExpr),
297    {
298        self.func.visit_mut_children(&mut f);
299        for expr in self.partition_by.iter_mut() {
300            f(expr);
301        }
302        for expr in self.order_by.iter_mut() {
303            f(expr);
304        }
305    }
306
307    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
308    where
309        F: FnMut(&HirScalarExpr) -> Result<(), E>,
310        E: From<RecursionLimitError>,
311    {
312        self.func.try_visit_children(&mut f)?;
313        for expr in self.partition_by.iter() {
314            f(expr)?;
315        }
316        for expr in self.order_by.iter() {
317            f(expr)?;
318        }
319        Ok(())
320    }
321
322    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
323    where
324        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
325        E: From<RecursionLimitError>,
326    {
327        self.func.try_visit_mut_children(&mut f)?;
328        for expr in self.partition_by.iter_mut() {
329            f(expr)?;
330        }
331        for expr in self.order_by.iter_mut() {
332            f(expr)?;
333        }
334        Ok(())
335    }
336}
337
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function (`row_number`, `rank`, `dense_rank`).
    Scalar(ScalarWindowExpr),
    /// A value window function (`lag`, `lead`, `first_value`, `last_value`, or a fusion of
    /// those).
    Value(ValueWindowExpr),
    /// An aggregate used as a window function.
    Aggregate(AggregateWindowExpr),
}
360
361impl WindowExprType {
362    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
363    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
364    where
365        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
366    {
367        #[allow(deprecated)]
368        match self {
369            Self::Scalar(expr) => expr.visit_expressions(f),
370            Self::Value(expr) => expr.visit_expressions(f),
371            Self::Aggregate(expr) => expr.visit_expressions(f),
372        }
373    }
374
375    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
376    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
377    where
378        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
379    {
380        #[allow(deprecated)]
381        match self {
382            Self::Scalar(expr) => expr.visit_expressions_mut(f),
383            Self::Value(expr) => expr.visit_expressions_mut(f),
384            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
385        }
386    }
387
388    fn typ(
389        &self,
390        outers: &[RelationType],
391        inner: &RelationType,
392        params: &BTreeMap<usize, ScalarType>,
393    ) -> ColumnType {
394        match self {
395            Self::Scalar(expr) => expr.typ(outers, inner, params),
396            Self::Value(expr) => expr.typ(outers, inner, params),
397            Self::Aggregate(expr) => expr.typ(outers, inner, params),
398        }
399    }
400}
401
402impl VisitChildren<HirScalarExpr> for WindowExprType {
403    fn visit_children<F>(&self, f: F)
404    where
405        F: FnMut(&HirScalarExpr),
406    {
407        match self {
408            Self::Scalar(_) => (),
409            Self::Value(expr) => expr.visit_children(f),
410            Self::Aggregate(expr) => expr.visit_children(f),
411        }
412    }
413
414    fn visit_mut_children<F>(&mut self, f: F)
415    where
416        F: FnMut(&mut HirScalarExpr),
417    {
418        match self {
419            Self::Scalar(_) => (),
420            Self::Value(expr) => expr.visit_mut_children(f),
421            Self::Aggregate(expr) => expr.visit_mut_children(f),
422        }
423    }
424
425    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
426    where
427        F: FnMut(&HirScalarExpr) -> Result<(), E>,
428        E: From<RecursionLimitError>,
429    {
430        match self {
431            Self::Scalar(_) => Ok(()),
432            Self::Value(expr) => expr.try_visit_children(f),
433            Self::Aggregate(expr) => expr.try_visit_children(f),
434        }
435    }
436
437    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
438    where
439        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
440        E: From<RecursionLimitError>,
441    {
442        match self {
443            Self::Scalar(_) => Ok(()),
444            Self::Value(expr) => expr.try_visit_mut_children(f),
445            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
446        }
447    }
448}
449
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A scalar window function together with the ordering it is evaluated under.
pub struct ScalarWindowExpr {
    /// Which scalar window function is invoked.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
455
456impl ScalarWindowExpr {
457    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
458    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
459    where
460        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
461    {
462        match self.func {
463            ScalarWindowFunc::RowNumber => {}
464            ScalarWindowFunc::Rank => {}
465            ScalarWindowFunc::DenseRank => {}
466        }
467        Ok(())
468    }
469
470    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
471    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
472    where
473        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
474    {
475        match self.func {
476            ScalarWindowFunc::RowNumber => {}
477            ScalarWindowFunc::Rank => {}
478            ScalarWindowFunc::DenseRank => {}
479        }
480        Ok(())
481    }
482
483    fn typ(
484        &self,
485        _outers: &[RelationType],
486        _inner: &RelationType,
487        _params: &BTreeMap<usize, ScalarType>,
488    ) -> ColumnType {
489        self.func.output_type()
490    }
491
492    pub fn into_expr(self) -> mz_expr::AggregateFunc {
493        match self.func {
494            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
495                order_by: self.order_by,
496            },
497            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
498                order_by: self.order_by,
499            },
500            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
501                order_by: self.order_by,
502            },
503        }
504    }
505}
506
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
514
515impl Display for ScalarWindowFunc {
516    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
517        match self {
518            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
519            ScalarWindowFunc::Rank => write!(f, "rank"),
520            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
521        }
522    }
523}
524
525impl ScalarWindowFunc {
526    pub fn output_type(&self) -> ColumnType {
527        match self {
528            ScalarWindowFunc::RowNumber => ScalarType::Int64.nullable(false),
529            ScalarWindowFunc::Rank => ScalarType::Int64.nullable(false),
530            ScalarWindowFunc::DenseRank => ScalarType::Int64.nullable(false),
531        }
532    }
533}
534
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A value window function together with its arguments, ordering and frame.
pub struct ValueWindowExpr {
    /// Which value window function is invoked.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. `ValueWindowFunc::into_expr` places it only in the
    /// `first_value`/`last_value` results (also when those are constituents of a fused
    /// function).
    pub window_frame: WindowFrame,
    /// The "ignore nulls" flag. `ValueWindowFunc::into_expr` places it only in the
    /// `lag`/`lead` results (also when those are constituents of a fused function).
    pub ignore_nulls: bool,
}
549
550impl Display for ValueWindowFunc {
551    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
552        match self {
553            ValueWindowFunc::Lag => write!(f, "lag"),
554            ValueWindowFunc::Lead => write!(f, "lead"),
555            ValueWindowFunc::FirstValue => write!(f, "first_value"),
556            ValueWindowFunc::LastValue => write!(f, "last_value"),
557            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
558        }
559    }
560}
561
562impl ValueWindowExpr {
563    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
564    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
565    where
566        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
567    {
568        f(&self.args)
569    }
570
571    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
572    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
573    where
574        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
575    {
576        f(&mut self.args)
577    }
578
579    fn typ(
580        &self,
581        outers: &[RelationType],
582        inner: &RelationType,
583        params: &BTreeMap<usize, ScalarType>,
584    ) -> ColumnType {
585        self.func.output_type(self.args.typ(outers, inner, params))
586    }
587
588    /// Converts into `mz_expr::AggregateFunc`.
589    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
590        (
591            self.args,
592            self.func
593                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
594        )
595    }
596}
597
598impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
599    fn visit_children<F>(&self, mut f: F)
600    where
601        F: FnMut(&HirScalarExpr),
602    {
603        f(&self.args)
604    }
605
606    fn visit_mut_children<F>(&mut self, mut f: F)
607    where
608        F: FnMut(&mut HirScalarExpr),
609    {
610        f(&mut self.args)
611    }
612
613    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
614    where
615        F: FnMut(&HirScalarExpr) -> Result<(), E>,
616        E: From<RecursionLimitError>,
617    {
618        f(&self.args)
619    }
620
621    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
622    where
623        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
624        E: From<RecursionLimitError>,
625    {
626        f(&mut self.args)
627    }
628}
629
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// Displayed as `lag`.
    Lag,
    /// Displayed as `lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions fused into one call; the result type is a record
    /// with one field per constituent function.
    Fused(Vec<ValueWindowFunc>),
}
639
640impl ValueWindowFunc {
641    pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
642        match self {
643            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
644                // The input is a (value, offset, default) record, so extract the type of the first arg
645                input_type.scalar_type.unwrap_record_element_type()[0]
646                    .clone()
647                    .nullable(true)
648            }
649            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
650                input_type.scalar_type.nullable(true)
651            }
652            ValueWindowFunc::Fused(funcs) => {
653                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
654                ScalarType::Record {
655                    fields: funcs
656                        .iter()
657                        .zip_eq(input_types)
658                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
659                        .collect(),
660                    custom_id: None,
661                }
662                .nullable(false)
663            }
664        }
665    }
666
667    pub fn into_expr(
668        self,
669        order_by: Vec<ColumnOrder>,
670        window_frame: WindowFrame,
671        ignore_nulls: bool,
672    ) -> mz_expr::AggregateFunc {
673        match self {
674            // Lag and Lead are fundamentally the same function, just with opposite directions
675            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
676                order_by,
677                lag_lead: mz_expr::LagLeadType::Lag,
678                ignore_nulls,
679            },
680            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
681                order_by,
682                lag_lead: mz_expr::LagLeadType::Lead,
683                ignore_nulls,
684            },
685            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
686                order_by,
687                window_frame,
688            },
689            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
690                order_by,
691                window_frame,
692            },
693            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
694                funcs: funcs
695                    .into_iter()
696                    .map(|func| {
697                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
698                    })
699                    .collect(),
700                order_by,
701            },
702        }
703    }
704}
705
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// An aggregate used as a window function, together with its ordering and frame.
pub struct AggregateWindowExpr {
    /// The aggregate function and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame the aggregate is evaluated over.
    pub window_frame: WindowFrame,
}
712
713impl AggregateWindowExpr {
714    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
715    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
716    where
717        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
718    {
719        f(&self.aggregate_expr.expr)
720    }
721
722    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
723    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
724    where
725        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
726    {
727        f(&mut self.aggregate_expr.expr)
728    }
729
730    fn typ(
731        &self,
732        outers: &[RelationType],
733        inner: &RelationType,
734        params: &BTreeMap<usize, ScalarType>,
735    ) -> ColumnType {
736        self.aggregate_expr
737            .func
738            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
739    }
740
741    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
742        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
743            (
744                self.aggregate_expr.expr,
745                FusedWindowAggregate {
746                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
747                    order_by: self.order_by,
748                    window_frame: self.window_frame,
749                },
750            )
751        } else {
752            (
753                self.aggregate_expr.expr,
754                WindowAggregate {
755                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
756                    order_by: self.order_by,
757                    window_frame: self.window_frame,
758                },
759            )
760        }
761    }
762}
763
764impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
765    fn visit_children<F>(&self, mut f: F)
766    where
767        F: FnMut(&HirScalarExpr),
768    {
769        f(&self.aggregate_expr.expr)
770    }
771
772    fn visit_mut_children<F>(&mut self, mut f: F)
773    where
774        F: FnMut(&mut HirScalarExpr),
775    {
776        f(&mut self.aggregate_expr.expr)
777    }
778
779    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
780    where
781        F: FnMut(&HirScalarExpr) -> Result<(), E>,
782        E: From<RecursionLimitError>,
783    {
784        f(&self.aggregate_expr.expr)
785    }
786
787    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
788    where
789        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
790        E: From<RecursionLimitError>,
791    {
792        f(&mut self.aggregate_expr.expr)
793    }
794}
795
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression whose type is already determined.
    Coerced(HirScalarExpr),
    /// A parameter (such as `$1`) whose type is not yet constrained.
    Parameter(usize),
    /// The literal `NULL`, not yet given a type.
    LiteralNull,
    /// A string literal (such as `'42'`), not yet given a type.
    LiteralString(String),
    /// A literal record whose fields may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
828
829impl CoercibleScalarExpr {
830    pub fn type_as(self, ecx: &ExprContext, ty: &ScalarType) -> Result<HirScalarExpr, PlanError> {
831        let expr = typeconv::plan_coerce(ecx, self, ty)?;
832        let expr_ty = ecx.scalar_type(&expr);
833        if ty != &expr_ty {
834            sql_bail!(
835                "{} must have type {}, not type {}",
836                ecx.name,
837                ecx.humanize_scalar_type(ty, false),
838                ecx.humanize_scalar_type(&expr_ty, false),
839            );
840        }
841        Ok(expr)
842    }
843
844    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
845        typeconv::plan_coerce(ecx, self, &ScalarType::String)
846    }
847
848    pub fn cast_to(
849        self,
850        ecx: &ExprContext,
851        ccx: CastContext,
852        ty: &ScalarType,
853    ) -> Result<HirScalarExpr, PlanError> {
854        let expr = typeconv::plan_coerce(ecx, self, ty)?;
855        typeconv::plan_cast(ecx, ccx, expr, ty)
856    }
857}
858
/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The type is fully determined.
    Coerced(ColumnType),
    /// A record whose field types may themselves still be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// The type is not determined yet.
    Uncoerced,
}
866
867impl CoercibleColumnType {
868    /// Reports the nullability of the type.
869    pub fn nullable(&self) -> bool {
870        match self {
871            // A coerced value's nullability is known.
872            CoercibleColumnType::Coerced(ct) => ct.nullable,
873
874            // A literal record can never be null.
875            CoercibleColumnType::Record(_) => false,
876
877            // An uncoerced literal may be the literal `NULL`, so we have
878            // to conservatively assume it is nullable.
879            CoercibleColumnType::Uncoerced => true,
880        }
881    }
882}
883
/// The scalar type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// The type is fully determined.
    Coerced(ScalarType),
    /// A record whose field types may themselves still be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// The type is not determined yet.
    Uncoerced,
}
891
892impl CoercibleScalarType {
893    /// Reports whether the scalar type has been coerced.
894    pub fn is_coerced(&self) -> bool {
895        matches!(self, CoercibleScalarType::Coerced(_))
896    }
897
898    /// Returns the coerced scalar type, if the type is coerced.
899    pub fn as_coerced(&self) -> Option<&ScalarType> {
900        match self {
901            CoercibleScalarType::Coerced(t) => Some(t),
902            _ => None,
903        }
904    }
905
906    /// If the type is coerced, apply the mapping function to the contained
907    /// scalar type.
908    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
909    where
910        F: FnOnce(ScalarType) -> ScalarType,
911    {
912        match self {
913            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
914            _ => self,
915        }
916    }
917
918    /// If the type is an coercible record, forcibly converts to a coerced
919    /// record type. Any uncoerced field types are assumed to be of type text.
920    ///
921    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
922    /// accepts a type hint that can indicate the types of uncoerced field
923    /// types.
924    pub fn force_coerced_if_record(&mut self) {
925        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> ScalarType {
926            let mut fields = vec![];
927            for (i, uf) in uncoerced_fields.enumerate() {
928                let name = ColumnName::from(format!("f{}", i + 1));
929                let ty = match uf {
930                    CoercibleColumnType::Coerced(ty) => ty,
931                    CoercibleColumnType::Record(mut fields) => {
932                        convert(fields.drain(..)).nullable(false)
933                    }
934                    CoercibleColumnType::Uncoerced => ScalarType::String.nullable(true),
935                };
936                fields.push((name, ty))
937            }
938            ScalarType::Record {
939                fields: fields.into(),
940                custom_id: None,
941            }
942        }
943
944        if let CoercibleScalarType::Record(fields) = self {
945            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
946        }
947    }
948}
949
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` are the types of the enclosing scopes and `inner` is the type
    /// of the relation the expression is evaluated against.
    /// NOTE(review): `params` presumably maps parameter positions to their
    /// types — confirm against callers.
    fn typ(
        &self,
        outers: &[RelationType],
        inner: &RelationType,
        params: &BTreeMap<usize, ScalarType>,
    ) -> Self::Type;
}
964
965impl AbstractExpr for CoercibleScalarExpr {
966    type Type = CoercibleColumnType;
967
968    fn typ(
969        &self,
970        outers: &[RelationType],
971        inner: &RelationType,
972        params: &BTreeMap<usize, ScalarType>,
973    ) -> Self::Type {
974        match self {
975            CoercibleScalarExpr::Coerced(expr) => {
976                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
977            }
978            CoercibleScalarExpr::LiteralRecord(scalars) => {
979                let fields = scalars
980                    .iter()
981                    .map(|s| s.typ(outers, inner, params))
982                    .collect();
983                CoercibleColumnType::Record(fields)
984            }
985            _ => CoercibleColumnType::Uncoerced,
986        }
987    }
988}
989
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `ColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object that this column type-like object wraps.
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object.
    ///
    /// Consumes `self`.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1001
1002impl AbstractColumnType for ColumnType {
1003    type AbstractScalarType = ScalarType;
1004
1005    fn scalar_type(self) -> Self::AbstractScalarType {
1006        self.scalar_type
1007    }
1008}
1009
1010impl AbstractColumnType for CoercibleColumnType {
1011    type AbstractScalarType = CoercibleScalarType;
1012
1013    fn scalar_type(self) -> Self::AbstractScalarType {
1014        match self {
1015            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1016            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1017            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1018        }
1019    }
1020}
1021
1022impl From<HirScalarExpr> for CoercibleScalarExpr {
1023    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1024        CoercibleScalarExpr::Coerced(expr)
1025    }
1026}
1027
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier.
    pub column: usize,
}
1049
/// The kind of a `HirRelationExpr::Join`.
///
/// The kind determines which side's columns are forced nullable when the
/// join is typed (see `HirRelationExpr::typ`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Neither side's columns are made nullable.
    Inner,
    /// The right side's columns are made nullable.
    LeftOuter,
    /// The left side's columns are made nullable.
    RightOuter,
    /// Both sides' columns are made nullable.
    FullOuter,
}
1057
1058impl fmt::Display for JoinKind {
1059    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1060        write!(
1061            f,
1062            "{}",
1063            match self {
1064                JoinKind::Inner => "Inner",
1065                JoinKind::LeftOuter => "LeftOuter",
1066                JoinKind::RightOuter => "RightOuter",
1067                JoinKind::FullOuter => "FullOuter",
1068            }
1069        )
1070    }
1071}
1072
1073impl JoinKind {
1074    pub fn can_be_correlated(&self) -> bool {
1075        match self {
1076            JoinKind::Inner | JoinKind::LeftOuter => true,
1077            JoinKind::RightOuter | JoinKind::FullOuter => false,
1078        }
1079    }
1080}
1081
/// An application of an [`AggregateFunc`] to a scalar expression, as used in
/// the `aggregates` of a `HirRelationExpr::Reduce`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are aggregated.
    pub expr: Box<HirScalarExpr>,
    /// NOTE(review): presumably whether only distinct input values are
    /// aggregated — confirm against the lowering code.
    pub distinct: bool,
}
1088
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, the nullability of the aggregate columns is more common
/// here than in `expr`, as these aggregates may be applied over empty
/// result sets and should be null in those cases, whereas `expr` variants
/// only return null values when supplied nulls as input.
///
/// Note that the derived `Ord` and the serialized form depend on the order in
/// which the variants are declared.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    Count,
    Any,
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: ScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// An order-sensitive aggregate whose output type is `ScalarType::String`.
    /// NOTE(review): presumably the counterpart of SQL `string_agg` — confirm.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, whose each
    /// component will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1192
impl AggregateFunc {
    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
    ///
    /// Every variant maps to the `mz_expr` variant of the same name, carrying
    /// over any `order_by` / `value_type` payload unchanged.
    ///
    /// # Panics
    ///
    /// Panics if called on a `FusedWindowAgg`.
    pub fn into_expr(self) -> mz_expr::AggregateFunc {
        match self {
            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
            AggregateFunc::All => mz_expr::AggregateFunc::All,
            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
            AggregateFunc::JsonbObjectAgg { order_by } => {
                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
            }
            AggregateFunc::MapAgg {
                order_by,
                value_type,
            } => mz_expr::AggregateFunc::MapAgg {
                order_by,
                value_type,
            },
            AggregateFunc::ArrayConcat { order_by } => {
                mz_expr::AggregateFunc::ArrayConcat { order_by }
            }
            AggregateFunc::ListConcat { order_by } => {
                mz_expr::AggregateFunc::ListConcat { order_by }
            }
            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
            // `AggregateFunc::FusedWindowAgg` should be specially handled in
            // `AggregateWindowExpr::into_expr`.
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                panic!("into_expr called on FusedWindowAgg")
            }
            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
        }
    }

    /// Returns a datum whose inclusion in the aggregation will not change its
    /// result.
    ///
    /// This is `Datum::Null` for every aggregate except `Any` (false), `All`
    /// (true), `Dummy` (dummy), `ArrayConcat` (empty array) and `ListConcat`
    /// (empty list).
    ///
    /// # Panics
    ///
    /// Panics if called on a `FusedWindowAgg`.
    pub fn identity_datum(&self) -> Datum<'static> {
        match self {
            AggregateFunc::Any => Datum::False,
            AggregateFunc::All => Datum::True,
            AggregateFunc::Dummy => Datum::Dummy,
            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumInt16
            | AggregateFunc::SumInt32
            | AggregateFunc::SumInt64
            | AggregateFunc::SumUInt16
            | AggregateFunc::SumUInt32
            | AggregateFunc::SumUInt64
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Count
            | AggregateFunc::JsonbAgg { .. }
            | AggregateFunc::JsonbObjectAgg { .. }
            | AggregateFunc::MapAgg { .. }
            | AggregateFunc::StringAgg { .. } => Datum::Null,
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
                // in HIR planning, because it is introduced only during HIR transformation.
                //
                // The implementation could be something like the following, except that we need to
                // return a `Datum<'static>`, so we can't actually dynamically compute this.
                // ```
                // let temp_storage = RowArena::new();
                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
                // ```
                panic!("FusedWindowAgg doesn't have an identity_datum")
            }
        }
    }

    /// The output column type for the result of an aggregation.
    ///
    /// The output column type also contains nullability information, which
    /// is (without further information) true for aggregations that are not
    /// counts.
    ///
    /// # Panics
    ///
    /// Panics if `ArrayConcat` or `ListConcat` is given an input type that is
    /// not a record with at least one field. For `FusedWindowAgg`, the
    /// `zip_eq` is expected to panic if the number of fused functions differs
    /// from the number of record components in the input type.
    pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
        let scalar_type = match self {
            AggregateFunc::Count => ScalarType::Int64,
            AggregateFunc::Any => ScalarType::Bool,
            AggregateFunc::All => ScalarType::Bool,
            AggregateFunc::JsonbAgg { .. } => ScalarType::Jsonb,
            AggregateFunc::JsonbObjectAgg { .. } => ScalarType::Jsonb,
            AggregateFunc::StringAgg { .. } => ScalarType::String,
            // Sums of the narrower integer types are typed as the 64-bit
            // integer of the same signedness; sums of 64-bit integers are
            // typed as scale-0 numerics.
            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => ScalarType::Int64,
            AggregateFunc::SumInt64 => ScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => ScalarType::UInt64,
            AggregateFunc::SumUInt64 => ScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::MapAgg { value_type, .. } => ScalarType::Map {
                value_type: Box::new(value_type.clone()),
                custom_id: None,
            },
            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
                match input_type.scalar_type {
                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
                    ScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
                    _ => unreachable!(),
                }
            }
            // All remaining aggregates produce the scalar type of their input.
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Dummy => input_type.scalar_type,
            AggregateFunc::FusedWindowAgg { funcs } => {
                // Each fused function is typed against the matching component
                // of the input record; the result is a record of the outputs
                // with empty field names.
                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
                ScalarType::Record {
                    fields: funcs
                        .iter()
                        .zip_eq(input_types)
                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
                        .collect(),
                    custom_id: None,
                }
            }
        };
        // max/min/sum return null on empty sets
        let nullable = !matches!(self, AggregateFunc::Count);
        scalar_type.nullable(nullable)
    }

    /// Reports whether this is one of the aggregates that carry an
    /// `order_by`: `JsonbAgg`, `JsonbObjectAgg`, `MapAgg`, `ArrayConcat`,
    /// `ListConcat` or `StringAgg`.
    pub fn is_order_sensitive(&self) -> bool {
        use AggregateFunc::*;
        matches!(
            self,
            JsonbAgg { .. }
                | JsonbObjectAgg { .. }
                | MapAgg { .. }
                | ArrayConcat { .. }
                | ListConcat { .. }
                | StringAgg { .. }
        )
    }
}
1446
1447impl HirRelationExpr {
1448    pub fn typ(
1449        &self,
1450        outers: &[RelationType],
1451        params: &BTreeMap<usize, ScalarType>,
1452    ) -> RelationType {
1453        stack::maybe_grow(|| match self {
1454            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1455            HirRelationExpr::Get { typ, .. } => typ.clone(),
1456            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1457            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1458            HirRelationExpr::Project { input, outputs } => {
1459                let input_typ = input.typ(outers, params);
1460                RelationType::new(
1461                    outputs
1462                        .iter()
1463                        .map(|&i| input_typ.column_types[i].clone())
1464                        .collect(),
1465                )
1466            }
1467            HirRelationExpr::Map { input, scalars } => {
1468                let mut typ = input.typ(outers, params);
1469                for scalar in scalars {
1470                    typ.column_types.push(scalar.typ(outers, &typ, params));
1471                }
1472                typ
1473            }
1474            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1475            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1476                input.typ(outers, params)
1477            }
1478            HirRelationExpr::Join {
1479                left, right, kind, ..
1480            } => {
1481                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1482                let right_nullable =
1483                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1484                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1485                    let nullable = t.nullable || left_nullable;
1486                    t.nullable(nullable)
1487                });
1488                let mut outers = outers.to_vec();
1489                outers.insert(0, RelationType::new(lt.clone().collect()));
1490                let rt = right
1491                    .typ(&outers, params)
1492                    .column_types
1493                    .into_iter()
1494                    .map(|t| {
1495                        let nullable = t.nullable || right_nullable;
1496                        t.nullable(nullable)
1497                    });
1498                RelationType::new(lt.chain(rt).collect())
1499            }
1500            HirRelationExpr::Reduce {
1501                input,
1502                group_key,
1503                aggregates,
1504                expected_group_size: _,
1505            } => {
1506                let input_typ = input.typ(outers, params);
1507                let mut column_types = group_key
1508                    .iter()
1509                    .map(|&i| input_typ.column_types[i].clone())
1510                    .collect::<Vec<_>>();
1511                for agg in aggregates {
1512                    column_types.push(agg.typ(outers, &input_typ, params));
1513                }
1514                // TODO(frank): add primary key information.
1515                RelationType::new(column_types)
1516            }
1517            // TODO(frank): check for removal; add primary key information.
1518            HirRelationExpr::Distinct { input }
1519            | HirRelationExpr::Negate { input }
1520            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1521            HirRelationExpr::Union { base, inputs } => {
1522                let mut base_cols = base.typ(outers, params).column_types;
1523                for input in inputs {
1524                    for (base_col, col) in base_cols
1525                        .iter_mut()
1526                        .zip_eq(input.typ(outers, params).column_types)
1527                    {
1528                        *base_col = base_col.union(&col).unwrap();
1529                    }
1530                }
1531                RelationType::new(base_cols)
1532            }
1533        })
1534    }
1535
1536    pub fn arity(&self) -> usize {
1537        match self {
1538            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1539            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1540            HirRelationExpr::Let { body, .. } => body.arity(),
1541            HirRelationExpr::LetRec { body, .. } => body.arity(),
1542            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1543            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1544            HirRelationExpr::CallTable { func, .. } => func.output_arity(),
1545            HirRelationExpr::Filter { input, .. }
1546            | HirRelationExpr::TopK { input, .. }
1547            | HirRelationExpr::Distinct { input }
1548            | HirRelationExpr::Negate { input }
1549            | HirRelationExpr::Threshold { input } => input.arity(),
1550            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1551            HirRelationExpr::Union { base, .. } => base.arity(),
1552            HirRelationExpr::Reduce {
1553                group_key,
1554                aggregates,
1555                ..
1556            } => group_key.len() + aggregates.len(),
1557        }
1558    }
1559
1560    /// If self is a constant, return the value and the type, otherwise `None`.
1561    pub fn as_const(&self) -> Option<(&Vec<Row>, &RelationType)> {
1562        match self {
1563            Self::Constant { rows, typ } => Some((rows, typ)),
1564            _ => None,
1565        }
1566    }
1567
1568    /// Reports whether this expression contains a column reference to its
1569    /// direct parent scope.
1570    pub fn is_correlated(&self) -> bool {
1571        let mut correlated = false;
1572        #[allow(deprecated)]
1573        self.visit_columns(0, &mut |depth, col| {
1574            if col.level > depth && col.level - depth == 1 {
1575                correlated = true;
1576            }
1577        });
1578        correlated
1579    }
1580
1581    pub fn is_join_identity(&self) -> bool {
1582        match self {
1583            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1584            _ => false,
1585        }
1586    }
1587
1588    pub fn project(self, outputs: Vec<usize>) -> Self {
1589        if outputs.iter().copied().eq(0..self.arity()) {
1590            // The projection is trivial. Suppress it.
1591            self
1592        } else {
1593            HirRelationExpr::Project {
1594                input: Box::new(self),
1595                outputs,
1596            }
1597        }
1598    }
1599
1600    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1601        if scalars.is_empty() {
1602            // The map is trivial. Suppress it.
1603            self
1604        } else if let HirRelationExpr::Map {
1605            scalars: old_scalars,
1606            input: _,
1607        } = &mut self
1608        {
1609            // Map applied to a map. Fuse the maps.
1610            old_scalars.extend(scalars);
1611            self
1612        } else {
1613            HirRelationExpr::Map {
1614                input: Box::new(self),
1615                scalars,
1616            }
1617        }
1618    }
1619
1620    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1621        if let HirRelationExpr::Filter {
1622            input: _,
1623            predicates,
1624        } = &mut self
1625        {
1626            predicates.extend(preds);
1627            predicates.sort();
1628            predicates.dedup();
1629            self
1630        } else {
1631            preds.sort();
1632            preds.dedup();
1633            HirRelationExpr::Filter {
1634                input: Box::new(self),
1635                predicates: preds,
1636            }
1637        }
1638    }
1639
1640    pub fn reduce(
1641        self,
1642        group_key: Vec<usize>,
1643        aggregates: Vec<AggregateExpr>,
1644        expected_group_size: Option<u64>,
1645    ) -> Self {
1646        HirRelationExpr::Reduce {
1647            input: Box::new(self),
1648            group_key,
1649            aggregates,
1650            expected_group_size,
1651        }
1652    }
1653
1654    pub fn top_k(
1655        self,
1656        group_key: Vec<usize>,
1657        order_key: Vec<ColumnOrder>,
1658        limit: Option<HirScalarExpr>,
1659        offset: usize,
1660        expected_group_size: Option<u64>,
1661    ) -> Self {
1662        HirRelationExpr::TopK {
1663            input: Box::new(self),
1664            group_key,
1665            order_key,
1666            limit,
1667            offset,
1668            expected_group_size,
1669        }
1670    }
1671
1672    pub fn negate(self) -> Self {
1673        if let HirRelationExpr::Negate { input } = self {
1674            *input
1675        } else {
1676            HirRelationExpr::Negate {
1677                input: Box::new(self),
1678            }
1679        }
1680    }
1681
1682    pub fn distinct(self) -> Self {
1683        if let HirRelationExpr::Distinct { .. } = self {
1684            self
1685        } else {
1686            HirRelationExpr::Distinct {
1687                input: Box::new(self),
1688            }
1689        }
1690    }
1691
1692    pub fn threshold(self) -> Self {
1693        if let HirRelationExpr::Threshold { .. } = self {
1694            self
1695        } else {
1696            HirRelationExpr::Threshold {
1697                input: Box::new(self),
1698            }
1699        }
1700    }
1701
1702    pub fn union(self, other: Self) -> Self {
1703        let mut terms = Vec::new();
1704        if let HirRelationExpr::Union { base, inputs } = self {
1705            terms.push(*base);
1706            terms.extend(inputs);
1707        } else {
1708            terms.push(self);
1709        }
1710        if let HirRelationExpr::Union { base, inputs } = other {
1711            terms.push(*base);
1712            terms.extend(inputs);
1713        } else {
1714            terms.push(other);
1715        }
1716        HirRelationExpr::Union {
1717            base: Box::new(terms.remove(0)),
1718            inputs: terms,
1719        }
1720    }
1721
1722    pub fn exists(self) -> HirScalarExpr {
1723        HirScalarExpr::Exists(Box::new(self))
1724    }
1725
1726    pub fn select(self) -> HirScalarExpr {
1727        HirScalarExpr::Select(Box::new(self))
1728    }
1729
1730    pub fn join(
1731        self,
1732        mut right: HirRelationExpr,
1733        on: HirScalarExpr,
1734        kind: JoinKind,
1735    ) -> HirRelationExpr {
1736        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1737        {
1738            // The join can be elided, but we need to adjust column references
1739            // on the right-hand side to account for the removal of the scope
1740            // introduced by the join.
1741            #[allow(deprecated)]
1742            right.visit_columns_mut(0, &mut |depth, col| {
1743                if col.level > depth {
1744                    col.level -= 1;
1745                }
1746            });
1747            right
1748        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1749            self
1750        } else {
1751            HirRelationExpr::Join {
1752                left: Box::new(self),
1753                right: Box::new(right),
1754                on,
1755                kind,
1756            }
1757        }
1758    }
1759
1760    pub fn take(&mut self) -> HirRelationExpr {
1761        mem::replace(
1762            self,
1763            HirRelationExpr::constant(vec![], RelationType::new(Vec::new())),
1764        )
1765    }
1766
1767    #[deprecated = "Use `Visit::visit_post`."]
1768    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1769    where
1770        F: FnMut(&'a Self, usize),
1771    {
1772        #[allow(deprecated)]
1773        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1774                                                 depth: usize|
1775         -> Result<(), ()> {
1776            f(e, depth);
1777            Ok(())
1778        });
1779    }
1780
1781    #[deprecated = "Use `Visit::try_visit_post`."]
1782    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1783    where
1784        F: FnMut(&'a Self, usize) -> Result<(), E>,
1785    {
1786        #[allow(deprecated)]
1787        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1788            e.visit_fallible(depth, f)
1789        })?;
1790        f(self, depth)
1791    }
1792
1793    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1794    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1795    where
1796        F: FnMut(&'a Self, usize) -> Result<(), E>,
1797    {
1798        match self {
1799            HirRelationExpr::Constant { .. }
1800            | HirRelationExpr::Get { .. }
1801            | HirRelationExpr::CallTable { .. } => (),
1802            HirRelationExpr::Let { body, value, .. } => {
1803                f(value, depth)?;
1804                f(body, depth)?;
1805            }
1806            HirRelationExpr::LetRec {
1807                limit: _,
1808                bindings,
1809                body,
1810            } => {
1811                for (_, _, value, _) in bindings.iter() {
1812                    f(value, depth)?;
1813                }
1814                f(body, depth)?;
1815            }
1816            HirRelationExpr::Project { input, .. } => {
1817                f(input, depth)?;
1818            }
1819            HirRelationExpr::Map { input, .. } => {
1820                f(input, depth)?;
1821            }
1822            HirRelationExpr::Filter { input, .. } => {
1823                f(input, depth)?;
1824            }
1825            HirRelationExpr::Join { left, right, .. } => {
1826                f(left, depth)?;
1827                f(right, depth + 1)?;
1828            }
1829            HirRelationExpr::Reduce { input, .. } => {
1830                f(input, depth)?;
1831            }
1832            HirRelationExpr::Distinct { input } => {
1833                f(input, depth)?;
1834            }
1835            HirRelationExpr::TopK { input, .. } => {
1836                f(input, depth)?;
1837            }
1838            HirRelationExpr::Negate { input } => {
1839                f(input, depth)?;
1840            }
1841            HirRelationExpr::Threshold { input } => {
1842                f(input, depth)?;
1843            }
1844            HirRelationExpr::Union { base, inputs } => {
1845                f(base, depth)?;
1846                for input in inputs {
1847                    f(input, depth)?;
1848                }
1849            }
1850        }
1851        Ok(())
1852    }
1853
1854    #[deprecated = "Use `Visit::visit_mut_post` instead."]
1855    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1856    where
1857        F: FnMut(&mut Self, usize),
1858    {
1859        #[allow(deprecated)]
1860        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1861                                                     depth: usize|
1862         -> Result<(), ()> {
1863            f(e, depth);
1864            Ok(())
1865        });
1866    }
1867
1868    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1869    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1870    where
1871        F: FnMut(&mut Self, usize) -> Result<(), E>,
1872    {
1873        #[allow(deprecated)]
1874        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1875            e.visit_mut_fallible(depth, f)
1876        })?;
1877        f(self, depth)
1878    }
1879
1880    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1881    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1882    where
1883        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1884    {
1885        match self {
1886            HirRelationExpr::Constant { .. }
1887            | HirRelationExpr::Get { .. }
1888            | HirRelationExpr::CallTable { .. } => (),
1889            HirRelationExpr::Let { body, value, .. } => {
1890                f(value, depth)?;
1891                f(body, depth)?;
1892            }
1893            HirRelationExpr::LetRec {
1894                limit: _,
1895                bindings,
1896                body,
1897            } => {
1898                for (_, _, value, _) in bindings.iter_mut() {
1899                    f(value, depth)?;
1900                }
1901                f(body, depth)?;
1902            }
1903            HirRelationExpr::Project { input, .. } => {
1904                f(input, depth)?;
1905            }
1906            HirRelationExpr::Map { input, .. } => {
1907                f(input, depth)?;
1908            }
1909            HirRelationExpr::Filter { input, .. } => {
1910                f(input, depth)?;
1911            }
1912            HirRelationExpr::Join { left, right, .. } => {
1913                f(left, depth)?;
1914                f(right, depth + 1)?;
1915            }
1916            HirRelationExpr::Reduce { input, .. } => {
1917                f(input, depth)?;
1918            }
1919            HirRelationExpr::Distinct { input } => {
1920                f(input, depth)?;
1921            }
1922            HirRelationExpr::TopK { input, .. } => {
1923                f(input, depth)?;
1924            }
1925            HirRelationExpr::Negate { input } => {
1926                f(input, depth)?;
1927            }
1928            HirRelationExpr::Threshold { input } => {
1929                f(input, depth)?;
1930            }
1931            HirRelationExpr::Union { base, inputs } => {
1932                f(base, depth)?;
1933                for input in inputs {
1934                    f(input, depth)?;
1935                }
1936            }
1937        }
1938        Ok(())
1939    }
1940
1941    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1942    /// Visits all scalar expressions within the sub-tree of the given relation.
1943    ///
1944    /// The `depth` argument should indicate the subquery nesting depth of the expression,
1945    /// which will be incremented when entering the RHS of a join or a subquery and
1946    /// presented to the supplied function `f`.
1947    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1948    where
1949        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1950    {
1951        #[allow(deprecated)]
1952        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1953                                         depth: usize|
1954         -> Result<(), E> {
1955            match e {
1956                HirRelationExpr::Join { on, .. } => {
1957                    f(on, depth)?;
1958                }
1959                HirRelationExpr::Map { scalars, .. } => {
1960                    for scalar in scalars {
1961                        f(scalar, depth)?;
1962                    }
1963                }
1964                HirRelationExpr::CallTable { exprs, .. } => {
1965                    for expr in exprs {
1966                        f(expr, depth)?;
1967                    }
1968                }
1969                HirRelationExpr::Filter { predicates, .. } => {
1970                    for predicate in predicates {
1971                        f(predicate, depth)?;
1972                    }
1973                }
1974                HirRelationExpr::Reduce { aggregates, .. } => {
1975                    for aggregate in aggregates {
1976                        f(&aggregate.expr, depth)?;
1977                    }
1978                }
1979                HirRelationExpr::TopK { limit, .. } => {
1980                    if let Some(limit) = limit {
1981                        f(limit, depth)?;
1982                    }
1983                }
1984                HirRelationExpr::Union { .. }
1985                | HirRelationExpr::Let { .. }
1986                | HirRelationExpr::LetRec { .. }
1987                | HirRelationExpr::Project { .. }
1988                | HirRelationExpr::Distinct { .. }
1989                | HirRelationExpr::Negate { .. }
1990                | HirRelationExpr::Threshold { .. }
1991                | HirRelationExpr::Constant { .. }
1992                | HirRelationExpr::Get { .. } => (),
1993            }
1994            Ok(())
1995        })
1996    }
1997
1998    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1999    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2000    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2001    where
2002        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2003    {
2004        #[allow(deprecated)]
2005        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2006                                             depth: usize|
2007         -> Result<(), E> {
2008            match e {
2009                HirRelationExpr::Join { on, .. } => {
2010                    f(on, depth)?;
2011                }
2012                HirRelationExpr::Map { scalars, .. } => {
2013                    for scalar in scalars.iter_mut() {
2014                        f(scalar, depth)?;
2015                    }
2016                }
2017                HirRelationExpr::CallTable { exprs, .. } => {
2018                    for expr in exprs.iter_mut() {
2019                        f(expr, depth)?;
2020                    }
2021                }
2022                HirRelationExpr::Filter { predicates, .. } => {
2023                    for predicate in predicates.iter_mut() {
2024                        f(predicate, depth)?;
2025                    }
2026                }
2027                HirRelationExpr::Reduce { aggregates, .. } => {
2028                    for aggregate in aggregates.iter_mut() {
2029                        f(&mut aggregate.expr, depth)?;
2030                    }
2031                }
2032                HirRelationExpr::TopK { limit, .. } => {
2033                    if let Some(limit) = limit {
2034                        f(limit, depth)?;
2035                    }
2036                }
2037                HirRelationExpr::Union { .. }
2038                | HirRelationExpr::Let { .. }
2039                | HirRelationExpr::LetRec { .. }
2040                | HirRelationExpr::Project { .. }
2041                | HirRelationExpr::Distinct { .. }
2042                | HirRelationExpr::Negate { .. }
2043                | HirRelationExpr::Threshold { .. }
2044                | HirRelationExpr::Constant { .. }
2045                | HirRelationExpr::Get { .. } => (),
2046            }
2047            Ok(())
2048        })
2049    }
2050
2051    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2052    /// Visits the column references in this relation expression.
2053    ///
2054    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2055    /// which will be incremented when entering the RHS of a join or a subquery and
2056    /// presented to the supplied function `f`.
2057    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2058    where
2059        F: FnMut(usize, &ColumnRef),
2060    {
2061        #[allow(deprecated)]
2062        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2063                                                           depth: usize|
2064         -> Result<(), ()> {
2065            e.visit_columns(depth, f);
2066            Ok(())
2067        });
2068    }
2069
2070    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2071    /// Like `visit_columns`, but permits mutating the column references.
2072    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2073    where
2074        F: FnMut(usize, &mut ColumnRef),
2075    {
2076        #[allow(deprecated)]
2077        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2078                                                               depth: usize|
2079         -> Result<(), ()> {
2080            e.visit_columns_mut(depth, f);
2081            Ok(())
2082        });
2083    }
2084
2085    /// Replaces any parameter references in the expression with the
2086    /// corresponding datum from `params`.
2087    pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2088        #[allow(deprecated)]
2089        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2090            e.bind_parameters(params)
2091        })
2092    }
2093
2094    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2095    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2096        #[allow(deprecated)]
2097        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2098                                                               depth: usize|
2099         -> Result<(), ()> {
2100            e.splice_parameters(params, depth);
2101            Ok(())
2102        });
2103    }
2104
2105    /// Constructs a constant collection from specific rows and schema.
2106    pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
2107        let rows = rows
2108            .into_iter()
2109            .map(move |datums| Row::pack_slice(&datums))
2110            .collect();
2111        HirRelationExpr::Constant { rows, typ }
2112    }
2113
2114    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2115    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2116    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2117    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2118    /// finishing into a trivial finishing.
2119    pub fn finish_maintained(
2120        &mut self,
2121        finishing: &mut RowSetFinishing<HirScalarExpr>,
2122        group_size_hints: GroupSizeHints,
2123    ) {
2124        if !finishing.is_trivial(self.arity()) {
2125            let old_finishing =
2126                mem::replace(finishing, RowSetFinishing::trivial(finishing.project.len()));
2127            *self = HirRelationExpr::top_k(
2128                std::mem::replace(
2129                    self,
2130                    HirRelationExpr::Constant {
2131                        rows: vec![],
2132                        typ: RelationType::new(Vec::new()),
2133                    },
2134                ),
2135                vec![],
2136                old_finishing.order_by,
2137                old_finishing.limit,
2138                old_finishing.offset,
2139                group_size_hints.limit_input_group_size,
2140            )
2141            .project(old_finishing.project);
2142        }
2143    }
2144
2145    /// The HirRelationExpr is considered potentially expensive if and only if
2146    /// at least one of the following conditions is true:
2147    ///
2148    ///  - It contains at least one CallTable or a Reduce operator.
2149    ///  - It contains at least one HirScalarExpr with a function call.
2150    ///
2151    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2152    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2153    pub fn could_run_expensive_function(&self) -> bool {
2154        let mut result = false;
2155        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2156            use HirRelationExpr::*;
2157            use HirScalarExpr::*;
2158
2159            self.visit_children(|scalar: &HirScalarExpr| {
2160                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2161                    result |= match scalar {
2162                        Column(_)
2163                        | Literal(_, _)
2164                        | CallUnmaterializable(_)
2165                        | If { .. }
2166                        | Parameter(..)
2167                        | Select(..)
2168                        | Exists(..) => false,
2169                        // Function calls are considered expensive
2170                        CallUnary { .. }
2171                        | CallBinary { .. }
2172                        | CallVariadic { .. }
2173                        | Windowing(..) => true,
2174                    };
2175                }) {
2176                    // Conservatively set `true` on RecursionLimitError.
2177                    result = true;
2178                }
2179            });
2180
2181            // CallTable has a table function; Reduce has an aggregate function.
2182            // Other constructs use MirScalarExpr to run a function
2183            result |= matches!(e, CallTable { .. } | Reduce { .. });
2184        }) {
2185            // Conservatively set `true` on RecursionLimitError.
2186            result = true;
2187        }
2188
2189        result
2190    }
2191
2192    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2193    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2194        let mut contains = false;
2195        self.visit_post(&mut |expr| {
2196            expr.visit_children(|expr: &HirScalarExpr| {
2197                contains = contains || expr.contains_temporal()
2198            })
2199        })?;
2200        Ok(contains)
2201    }
2202}
2203
2204impl CollectionPlan for HirRelationExpr {
2205    // !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2206    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2207    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2208        if let Self::Get {
2209            id: Id::Global(id), ..
2210        } = self
2211        {
2212            out.insert(*id);
2213        }
2214        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2215    }
2216}
2217
2218impl VisitChildren<Self> for HirRelationExpr {
2219    fn visit_children<F>(&self, mut f: F)
2220    where
2221        F: FnMut(&Self),
2222    {
2223        // subqueries of type HirRelationExpr might be wrapped in
2224        // Exists or Select variants within HirScalarExpr trees
2225        // attached at the current node, and we want to visit them as well
2226        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2227            #[allow(deprecated)]
2228            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2229                HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_ref()),
2230                _ => (),
2231            });
2232        });
2233
2234        use HirRelationExpr::*;
2235        match self {
2236            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2237            Let {
2238                name: _,
2239                id: _,
2240                value,
2241                body,
2242            } => {
2243                f(value);
2244                f(body);
2245            }
2246            LetRec {
2247                limit: _,
2248                bindings,
2249                body,
2250            } => {
2251                for (_, _, value, _) in bindings.iter() {
2252                    f(value);
2253                }
2254                f(body);
2255            }
2256            Project { input, outputs: _ } => f(input),
2257            Map { input, scalars: _ } => {
2258                f(input);
2259            }
2260            CallTable { func: _, exprs: _ } => (),
2261            Filter {
2262                input,
2263                predicates: _,
2264            } => {
2265                f(input);
2266            }
2267            Join {
2268                left,
2269                right,
2270                on: _,
2271                kind: _,
2272            } => {
2273                f(left);
2274                f(right);
2275            }
2276            Reduce {
2277                input,
2278                group_key: _,
2279                aggregates: _,
2280                expected_group_size: _,
2281            } => {
2282                f(input);
2283            }
2284            Distinct { input }
2285            | TopK {
2286                input,
2287                group_key: _,
2288                order_key: _,
2289                limit: _,
2290                offset: _,
2291                expected_group_size: _,
2292            }
2293            | Negate { input }
2294            | Threshold { input } => {
2295                f(input);
2296            }
2297            Union { base, inputs } => {
2298                f(base);
2299                for input in inputs {
2300                    f(input);
2301                }
2302            }
2303        }
2304    }
2305
2306    fn visit_mut_children<F>(&mut self, mut f: F)
2307    where
2308        F: FnMut(&mut Self),
2309    {
2310        // subqueries of type HirRelationExpr might be wrapped in
2311        // Exists or Select variants within HirScalarExpr trees
2312        // attached at the current node, and we want to visit them as well
2313        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2314            #[allow(deprecated)]
2315            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2316                HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_mut()),
2317                _ => (),
2318            });
2319        });
2320
2321        use HirRelationExpr::*;
2322        match self {
2323            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2324            Let {
2325                name: _,
2326                id: _,
2327                value,
2328                body,
2329            } => {
2330                f(value);
2331                f(body);
2332            }
2333            LetRec {
2334                limit: _,
2335                bindings,
2336                body,
2337            } => {
2338                for (_, _, value, _) in bindings.iter_mut() {
2339                    f(value);
2340                }
2341                f(body);
2342            }
2343            Project { input, outputs: _ } => f(input),
2344            Map { input, scalars: _ } => {
2345                f(input);
2346            }
2347            CallTable { func: _, exprs: _ } => (),
2348            Filter {
2349                input,
2350                predicates: _,
2351            } => {
2352                f(input);
2353            }
2354            Join {
2355                left,
2356                right,
2357                on: _,
2358                kind: _,
2359            } => {
2360                f(left);
2361                f(right);
2362            }
2363            Reduce {
2364                input,
2365                group_key: _,
2366                aggregates: _,
2367                expected_group_size: _,
2368            } => {
2369                f(input);
2370            }
2371            Distinct { input }
2372            | TopK {
2373                input,
2374                group_key: _,
2375                order_key: _,
2376                limit: _,
2377                offset: _,
2378                expected_group_size: _,
2379            }
2380            | Negate { input }
2381            | Threshold { input } => {
2382                f(input);
2383            }
2384            Union { base, inputs } => {
2385                f(base);
2386                for input in inputs {
2387                    f(input);
2388                }
2389            }
2390        }
2391    }
2392
2393    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2394    where
2395        F: FnMut(&Self) -> Result<(), E>,
2396        E: From<RecursionLimitError>,
2397    {
2398        // subqueries of type HirRelationExpr might be wrapped in
2399        // Exists or Select variants within HirScalarExpr trees
2400        // attached at the current node, and we want to visit them as well
2401        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2402            Visit::try_visit_post(expr, &mut |expr| match expr {
2403                HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_ref()),
2404                _ => Ok(()),
2405            })
2406        })?;
2407
2408        use HirRelationExpr::*;
2409        match self {
2410            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2411            Let {
2412                name: _,
2413                id: _,
2414                value,
2415                body,
2416            } => {
2417                f(value)?;
2418                f(body)?;
2419            }
2420            LetRec {
2421                limit: _,
2422                bindings,
2423                body,
2424            } => {
2425                for (_, _, value, _) in bindings.iter() {
2426                    f(value)?;
2427                }
2428                f(body)?;
2429            }
2430            Project { input, outputs: _ } => f(input)?,
2431            Map { input, scalars: _ } => {
2432                f(input)?;
2433            }
2434            CallTable { func: _, exprs: _ } => (),
2435            Filter {
2436                input,
2437                predicates: _,
2438            } => {
2439                f(input)?;
2440            }
2441            Join {
2442                left,
2443                right,
2444                on: _,
2445                kind: _,
2446            } => {
2447                f(left)?;
2448                f(right)?;
2449            }
2450            Reduce {
2451                input,
2452                group_key: _,
2453                aggregates: _,
2454                expected_group_size: _,
2455            } => {
2456                f(input)?;
2457            }
2458            Distinct { input }
2459            | TopK {
2460                input,
2461                group_key: _,
2462                order_key: _,
2463                limit: _,
2464                offset: _,
2465                expected_group_size: _,
2466            }
2467            | Negate { input }
2468            | Threshold { input } => {
2469                f(input)?;
2470            }
2471            Union { base, inputs } => {
2472                f(base)?;
2473                for input in inputs {
2474                    f(input)?;
2475                }
2476            }
2477        }
2478        Ok(())
2479    }
2480
2481    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2482    where
2483        F: FnMut(&mut Self) -> Result<(), E>,
2484        E: From<RecursionLimitError>,
2485    {
2486        // subqueries of type HirRelationExpr might be wrapped in
2487        // Exists or Select variants within HirScalarExpr trees
2488        // attached at the current node, and we want to visit them as well
2489        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2490            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2491                HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_mut()),
2492                _ => Ok(()),
2493            })
2494        })?;
2495
2496        use HirRelationExpr::*;
2497        match self {
2498            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2499            Let {
2500                name: _,
2501                id: _,
2502                value,
2503                body,
2504            } => {
2505                f(value)?;
2506                f(body)?;
2507            }
2508            LetRec {
2509                limit: _,
2510                bindings,
2511                body,
2512            } => {
2513                for (_, _, value, _) in bindings.iter_mut() {
2514                    f(value)?;
2515                }
2516                f(body)?;
2517            }
2518            Project { input, outputs: _ } => f(input)?,
2519            Map { input, scalars: _ } => {
2520                f(input)?;
2521            }
2522            CallTable { func: _, exprs: _ } => (),
2523            Filter {
2524                input,
2525                predicates: _,
2526            } => {
2527                f(input)?;
2528            }
2529            Join {
2530                left,
2531                right,
2532                on: _,
2533                kind: _,
2534            } => {
2535                f(left)?;
2536                f(right)?;
2537            }
2538            Reduce {
2539                input,
2540                group_key: _,
2541                aggregates: _,
2542                expected_group_size: _,
2543            } => {
2544                f(input)?;
2545            }
2546            Distinct { input }
2547            | TopK {
2548                input,
2549                group_key: _,
2550                order_key: _,
2551                limit: _,
2552                offset: _,
2553                expected_group_size: _,
2554            }
2555            | Negate { input }
2556            | Threshold { input } => {
2557                f(input)?;
2558            }
2559            Union { base, inputs } => {
2560                f(base)?;
2561                for input in inputs {
2562                    f(input)?;
2563                }
2564            }
2565        }
2566        Ok(())
2567    }
2568}
2569
2570impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2571    fn visit_children<F>(&self, mut f: F)
2572    where
2573        F: FnMut(&HirScalarExpr),
2574    {
2575        use HirRelationExpr::*;
2576        match self {
2577            Constant { rows: _, typ: _ }
2578            | Get { id: _, typ: _ }
2579            | Let {
2580                name: _,
2581                id: _,
2582                value: _,
2583                body: _,
2584            }
2585            | LetRec {
2586                limit: _,
2587                bindings: _,
2588                body: _,
2589            }
2590            | Project {
2591                input: _,
2592                outputs: _,
2593            } => (),
2594            Map { input: _, scalars } => {
2595                for scalar in scalars {
2596                    f(scalar);
2597                }
2598            }
2599            CallTable { func: _, exprs } => {
2600                for expr in exprs {
2601                    f(expr);
2602                }
2603            }
2604            Filter {
2605                input: _,
2606                predicates,
2607            } => {
2608                for predicate in predicates {
2609                    f(predicate);
2610                }
2611            }
2612            Join {
2613                left: _,
2614                right: _,
2615                on,
2616                kind: _,
2617            } => f(on),
2618            Reduce {
2619                input: _,
2620                group_key: _,
2621                aggregates,
2622                expected_group_size: _,
2623            } => {
2624                for aggregate in aggregates {
2625                    f(aggregate.expr.as_ref());
2626                }
2627            }
2628            TopK {
2629                input: _,
2630                group_key: _,
2631                order_key: _,
2632                limit,
2633                offset: _,
2634                expected_group_size: _,
2635            } => {
2636                if let Some(limit) = limit {
2637                    f(limit)
2638                }
2639            }
2640            Distinct { input: _ }
2641            | Negate { input: _ }
2642            | Threshold { input: _ }
2643            | Union { base: _, inputs: _ } => (),
2644        }
2645    }
2646
2647    fn visit_mut_children<F>(&mut self, mut f: F)
2648    where
2649        F: FnMut(&mut HirScalarExpr),
2650    {
2651        use HirRelationExpr::*;
2652        match self {
2653            Constant { rows: _, typ: _ }
2654            | Get { id: _, typ: _ }
2655            | Let {
2656                name: _,
2657                id: _,
2658                value: _,
2659                body: _,
2660            }
2661            | LetRec {
2662                limit: _,
2663                bindings: _,
2664                body: _,
2665            }
2666            | Project {
2667                input: _,
2668                outputs: _,
2669            } => (),
2670            Map { input: _, scalars } => {
2671                for scalar in scalars {
2672                    f(scalar);
2673                }
2674            }
2675            CallTable { func: _, exprs } => {
2676                for expr in exprs {
2677                    f(expr);
2678                }
2679            }
2680            Filter {
2681                input: _,
2682                predicates,
2683            } => {
2684                for predicate in predicates {
2685                    f(predicate);
2686                }
2687            }
2688            Join {
2689                left: _,
2690                right: _,
2691                on,
2692                kind: _,
2693            } => f(on),
2694            Reduce {
2695                input: _,
2696                group_key: _,
2697                aggregates,
2698                expected_group_size: _,
2699            } => {
2700                for aggregate in aggregates {
2701                    f(aggregate.expr.as_mut());
2702                }
2703            }
2704            Distinct { input: _ }
2705            | TopK {
2706                input: _,
2707                group_key: _,
2708                order_key: _,
2709                limit: _,
2710                offset: _,
2711                expected_group_size: _,
2712            }
2713            | Negate { input: _ }
2714            | Threshold { input: _ }
2715            | Union { base: _, inputs: _ } => (),
2716        }
2717    }
2718
2719    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2720    where
2721        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2722        E: From<RecursionLimitError>,
2723    {
2724        use HirRelationExpr::*;
2725        match self {
2726            Constant { rows: _, typ: _ }
2727            | Get { id: _, typ: _ }
2728            | Let {
2729                name: _,
2730                id: _,
2731                value: _,
2732                body: _,
2733            }
2734            | LetRec {
2735                limit: _,
2736                bindings: _,
2737                body: _,
2738            }
2739            | Project {
2740                input: _,
2741                outputs: _,
2742            } => (),
2743            Map { input: _, scalars } => {
2744                for scalar in scalars {
2745                    f(scalar)?;
2746                }
2747            }
2748            CallTable { func: _, exprs } => {
2749                for expr in exprs {
2750                    f(expr)?;
2751                }
2752            }
2753            Filter {
2754                input: _,
2755                predicates,
2756            } => {
2757                for predicate in predicates {
2758                    f(predicate)?;
2759                }
2760            }
2761            Join {
2762                left: _,
2763                right: _,
2764                on,
2765                kind: _,
2766            } => f(on)?,
2767            Reduce {
2768                input: _,
2769                group_key: _,
2770                aggregates,
2771                expected_group_size: _,
2772            } => {
2773                for aggregate in aggregates {
2774                    f(aggregate.expr.as_ref())?;
2775                }
2776            }
2777            Distinct { input: _ }
2778            | TopK {
2779                input: _,
2780                group_key: _,
2781                order_key: _,
2782                limit: _,
2783                offset: _,
2784                expected_group_size: _,
2785            }
2786            | Negate { input: _ }
2787            | Threshold { input: _ }
2788            | Union { base: _, inputs: _ } => (),
2789        }
2790        Ok(())
2791    }
2792
2793    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2794    where
2795        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2796        E: From<RecursionLimitError>,
2797    {
2798        use HirRelationExpr::*;
2799        match self {
2800            Constant { rows: _, typ: _ }
2801            | Get { id: _, typ: _ }
2802            | Let {
2803                name: _,
2804                id: _,
2805                value: _,
2806                body: _,
2807            }
2808            | LetRec {
2809                limit: _,
2810                bindings: _,
2811                body: _,
2812            }
2813            | Project {
2814                input: _,
2815                outputs: _,
2816            } => (),
2817            Map { input: _, scalars } => {
2818                for scalar in scalars {
2819                    f(scalar)?;
2820                }
2821            }
2822            CallTable { func: _, exprs } => {
2823                for expr in exprs {
2824                    f(expr)?;
2825                }
2826            }
2827            Filter {
2828                input: _,
2829                predicates,
2830            } => {
2831                for predicate in predicates {
2832                    f(predicate)?;
2833                }
2834            }
2835            Join {
2836                left: _,
2837                right: _,
2838                on,
2839                kind: _,
2840            } => f(on)?,
2841            Reduce {
2842                input: _,
2843                group_key: _,
2844                aggregates,
2845                expected_group_size: _,
2846            } => {
2847                for aggregate in aggregates {
2848                    f(aggregate.expr.as_mut())?;
2849                }
2850            }
2851            Distinct { input: _ }
2852            | TopK {
2853                input: _,
2854                group_key: _,
2855                order_key: _,
2856                limit: _,
2857                offset: _,
2858                expected_group_size: _,
2859            }
2860            | Negate { input: _ }
2861            | Threshold { input: _ }
2862            | Union { base: _, inputs: _ } => (),
2863        }
2864        Ok(())
2865    }
2866}
2867
2868impl HirScalarExpr {
2869    /// Replaces any parameter references in the expression with the
2870    /// corresponding datum in `params`.
2871    pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2872        #[allow(deprecated)]
2873        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
2874            if let HirScalarExpr::Parameter(n) = e {
2875                let datum = match params.datums.iter().nth(*n - 1) {
2876                    None => sql_bail!("there is no parameter ${}", n),
2877                    Some(datum) => datum,
2878                };
2879                let scalar_type = &params.types[*n - 1];
2880                let row = Row::pack([datum]);
2881                let column_type = scalar_type.clone().nullable(datum.is_null());
2882                *e = HirScalarExpr::Literal(row, column_type);
2883            }
2884            Ok(())
2885        })
2886    }
2887
2888    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
2889    /// replaced with the corresponding expression fragment from `params` rather
2890    /// than a datum.
2891    ///
2892    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
2893    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
2894    /// in `self` that refer to invalid indices of `params` will cause a panic.
2895    ///
2896    /// Column references in parameters will be corrected to account for the
2897    /// depth at which they are spliced.
2898    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2899        #[allow(deprecated)]
2900        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
2901                                                        e: &mut HirScalarExpr|
2902         -> Result<(), ()> {
2903            if let HirScalarExpr::Parameter(i) = e {
2904                *e = params[*i - 1].clone();
2905                // Correct any column references in the parameter expression for
2906                // its new depth.
2907                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
2908                    if col.level >= d {
2909                        col.level += depth
2910                    }
2911                });
2912            }
2913            Ok(())
2914        });
2915    }
2916
2917    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2918    pub fn contains_temporal(&self) -> bool {
2919        let mut contains = false;
2920        #[allow(deprecated)]
2921        self.visit_post_nolimit(&mut |e| {
2922            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2923                contains = true;
2924            }
2925        });
2926        contains
2927    }
2928
2929    /// Constructs a column reference in the current scope.
2930    pub fn column(index: usize) -> HirScalarExpr {
2931        HirScalarExpr::Column(ColumnRef {
2932            level: 0,
2933            column: index,
2934        })
2935    }
2936
2937    pub fn literal(datum: Datum, scalar_type: ScalarType) -> HirScalarExpr {
2938        let row = Row::pack([datum]);
2939        HirScalarExpr::Literal(row, scalar_type.nullable(datum.is_null()))
2940    }
2941
2942    pub fn literal_true() -> HirScalarExpr {
2943        HirScalarExpr::literal(Datum::True, ScalarType::Bool)
2944    }
2945
2946    pub fn literal_false() -> HirScalarExpr {
2947        HirScalarExpr::literal(Datum::False, ScalarType::Bool)
2948    }
2949
2950    pub fn literal_null(scalar_type: ScalarType) -> HirScalarExpr {
2951        HirScalarExpr::literal(Datum::Null, scalar_type)
2952    }
2953
2954    pub fn literal_1d_array(
2955        datums: Vec<Datum>,
2956        element_scalar_type: ScalarType,
2957    ) -> Result<HirScalarExpr, PlanError> {
2958        let scalar_type = match element_scalar_type {
2959            ScalarType::Array(_) => {
2960                sql_bail!("cannot build array from array type");
2961            }
2962            typ => ScalarType::Array(Box::new(typ)).nullable(false),
2963        };
2964
2965        let mut row = Row::default();
2966        row.packer()
2967            .try_push_array(
2968                &[ArrayDimension {
2969                    lower_bound: 1,
2970                    length: datums.len(),
2971                }],
2972                datums,
2973            )
2974            .expect("array constructed to be valid");
2975
2976        Ok(HirScalarExpr::Literal(row, scalar_type))
2977    }
2978
2979    pub fn as_literal(&self) -> Option<Datum> {
2980        if let HirScalarExpr::Literal(row, _column_type) = self {
2981            Some(row.unpack_first())
2982        } else {
2983            None
2984        }
2985    }
2986
2987    pub fn is_literal_true(&self) -> bool {
2988        Some(Datum::True) == self.as_literal()
2989    }
2990
2991    pub fn is_literal_false(&self) -> bool {
2992        Some(Datum::False) == self.as_literal()
2993    }
2994
2995    pub fn is_literal_null(&self) -> bool {
2996        Some(Datum::Null) == self.as_literal()
2997    }
2998
2999    /// Return true iff `self` consists only of constants, function calls, and
3000    /// if-else statements.
3001    pub fn is_constant(&self) -> bool {
3002        let mut worklist = vec![self];
3003        while let Some(expr) = worklist.pop() {
3004            match expr {
3005                Self::Literal(_, _) => {
3006                    // leaf node, do nothing
3007                }
3008                Self::CallUnary { expr, .. } => {
3009                    worklist.push(expr);
3010                }
3011                Self::CallBinary {
3012                    func: _,
3013                    expr1,
3014                    expr2,
3015                } => {
3016                    worklist.push(expr1);
3017                    worklist.push(expr2);
3018                }
3019                Self::CallVariadic { func: _, exprs } => {
3020                    worklist.extend(exprs.iter());
3021                }
3022                Self::If { cond, then, els } => {
3023                    worklist.push(cond);
3024                    worklist.push(then);
3025                    worklist.push(els);
3026                }
3027                _ => {
3028                    return false; // Any other node makes `self` non-constant.
3029                }
3030            }
3031        }
3032        true
3033    }
3034
3035    pub fn call_unary(self, func: UnaryFunc) -> Self {
3036        HirScalarExpr::CallUnary {
3037            func,
3038            expr: Box::new(self),
3039        }
3040    }
3041
3042    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
3043        HirScalarExpr::CallBinary {
3044            func,
3045            expr1: Box::new(self),
3046            expr2: Box::new(other),
3047        }
3048    }
3049
3050    pub fn or(self, other: Self) -> Self {
3051        HirScalarExpr::CallVariadic {
3052            func: VariadicFunc::Or,
3053            exprs: vec![self, other],
3054        }
3055    }
3056
3057    pub fn and(self, other: Self) -> Self {
3058        HirScalarExpr::CallVariadic {
3059            func: VariadicFunc::And,
3060            exprs: vec![self, other],
3061        }
3062    }
3063
3064    pub fn not(self) -> Self {
3065        self.call_unary(UnaryFunc::Not(func::Not))
3066    }
3067
3068    pub fn call_is_null(self) -> Self {
3069        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3070    }
3071
3072    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3073    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3074        match args.len() {
3075            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3076            1 => args.swap_remove(0),
3077            _ => HirScalarExpr::CallVariadic {
3078                func: VariadicFunc::And,
3079                exprs: args,
3080            },
3081        }
3082    }
3083
3084    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3085    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3086        match args.len() {
3087            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3088            1 => args.swap_remove(0),
3089            _ => HirScalarExpr::CallVariadic {
3090                func: VariadicFunc::Or,
3091                exprs: args,
3092            },
3093        }
3094    }
3095
3096    pub fn take(&mut self) -> Self {
3097        mem::replace(self, HirScalarExpr::literal_null(ScalarType::String))
3098    }
3099
3100    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3101    /// Visits the column references in this scalar expression.
3102    ///
3103    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3104    /// which will be incremented with each subquery entered and presented to the supplied
3105    /// function `f`.
3106    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3107    where
3108        F: FnMut(usize, &ColumnRef),
3109    {
3110        #[allow(deprecated)]
3111        let _ = self.visit_recursively(depth, &mut |depth: usize,
3112                                                    e: &HirScalarExpr|
3113         -> Result<(), ()> {
3114            if let HirScalarExpr::Column(col) = e {
3115                f(depth, col)
3116            }
3117            Ok(())
3118        });
3119    }
3120
3121    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3122    /// Like `visit_columns`, but permits mutating the column references.
3123    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3124    where
3125        F: FnMut(usize, &mut ColumnRef),
3126    {
3127        #[allow(deprecated)]
3128        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3129                                                        e: &mut HirScalarExpr|
3130         -> Result<(), ()> {
3131            if let HirScalarExpr::Column(col) = e {
3132                f(depth, col)
3133            }
3134            Ok(())
3135        });
3136    }
3137
3138    /// Visits those column references in this scalar expression that refer to the root
3139    /// level. These include column references that are at the root level, as well as column
3140    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3141    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3142    /// "root level" to be `self`'s level.)
3143    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3144    where
3145        F: FnMut(usize),
3146    {
3147        #[allow(deprecated)]
3148        let _ = self.visit_recursively(0, &mut |depth: usize,
3149                                                e: &HirScalarExpr|
3150         -> Result<(), ()> {
3151            if let HirScalarExpr::Column(col) = e {
3152                if col.level == depth {
3153                    f(col.column)
3154                }
3155            }
3156            Ok(())
3157        });
3158    }
3159
3160    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3161    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3162    where
3163        F: FnMut(&mut usize),
3164    {
3165        #[allow(deprecated)]
3166        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3167                                                    e: &mut HirScalarExpr|
3168         -> Result<(), ()> {
3169            if let HirScalarExpr::Column(col) = e {
3170                if col.level == depth {
3171                    f(&mut col.column)
3172                }
3173            }
3174            Ok(())
3175        });
3176    }
3177
3178    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3179    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3180    /// in them. It takes the current depth of the expression and increases it when
3181    /// entering a subquery.
3182    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3183    where
3184        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3185    {
3186        match self {
3187            HirScalarExpr::Literal(_, _)
3188            | HirScalarExpr::Parameter(_)
3189            | HirScalarExpr::CallUnmaterializable(_)
3190            | HirScalarExpr::Column(_) => (),
3191            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3192            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3193                expr1.visit_recursively(depth, f)?;
3194                expr2.visit_recursively(depth, f)?;
3195            }
3196            HirScalarExpr::CallVariadic { exprs, .. } => {
3197                for expr in exprs {
3198                    expr.visit_recursively(depth, f)?;
3199                }
3200            }
3201            HirScalarExpr::If { cond, then, els } => {
3202                cond.visit_recursively(depth, f)?;
3203                then.visit_recursively(depth, f)?;
3204                els.visit_recursively(depth, f)?;
3205            }
3206            HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => {
3207                #[allow(deprecated)]
3208                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3209                    e.visit_recursively(depth, f)
3210                })?;
3211            }
3212            HirScalarExpr::Windowing(expr) => {
3213                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3214            }
3215        }
3216        f(depth, self)
3217    }
3218
3219    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3220    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3221    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3222    where
3223        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3224    {
3225        match self {
3226            HirScalarExpr::Literal(_, _)
3227            | HirScalarExpr::Parameter(_)
3228            | HirScalarExpr::CallUnmaterializable(_)
3229            | HirScalarExpr::Column(_) => (),
3230            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3231            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3232                expr1.visit_recursively_mut(depth, f)?;
3233                expr2.visit_recursively_mut(depth, f)?;
3234            }
3235            HirScalarExpr::CallVariadic { exprs, .. } => {
3236                for expr in exprs {
3237                    expr.visit_recursively_mut(depth, f)?;
3238                }
3239            }
3240            HirScalarExpr::If { cond, then, els } => {
3241                cond.visit_recursively_mut(depth, f)?;
3242                then.visit_recursively_mut(depth, f)?;
3243                els.visit_recursively_mut(depth, f)?;
3244            }
3245            HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => {
3246                #[allow(deprecated)]
3247                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3248                    e.visit_recursively_mut(depth, f)
3249                })?;
3250            }
3251            HirScalarExpr::Windowing(expr) => {
3252                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3253            }
3254        }
3255        f(depth, self)
3256    }
3257
3258    fn simplify_to_literal(self) -> Option<Row> {
3259        let mut expr = self.lower_uncorrelated().ok()?;
3260        expr.reduce(&[]);
3261        match expr {
3262            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3263            _ => None,
3264        }
3265    }
3266
3267    /// Attempts to simplify this expression to a literal 64-bit integer.
3268    ///
3269    /// Returns `None` if this expression cannot be simplified, e.g. because it
3270    /// contains non-literal values.
3271    ///
3272    /// # Panics
3273    ///
3274    /// Panics if this expression does not have type [`ScalarType::Int64`].
3275    pub fn into_literal_int64(self) -> Option<i64> {
3276        self.simplify_to_literal().and_then(|row| {
3277            let datum = row.unpack_first();
3278            if datum.is_null() {
3279                None
3280            } else {
3281                Some(datum.unwrap_int64())
3282            }
3283        })
3284    }
3285
3286    /// Attempts to simplify this expression to a literal string.
3287    ///
3288    /// Returns `None` if this expression cannot be simplified, e.g. because it
3289    /// contains non-literal values.
3290    ///
3291    /// # Panics
3292    ///
3293    /// Panics if this expression does not have type [`ScalarType::String`].
3294    pub fn into_literal_string(self) -> Option<String> {
3295        self.simplify_to_literal().and_then(|row| {
3296            let datum = row.unpack_first();
3297            if datum.is_null() {
3298                None
3299            } else {
3300                Some(datum.unwrap_str().to_owned())
3301            }
3302        })
3303    }
3304
3305    /// Attempts to simplify this expression to a literal MzTimestamp.
3306    ///
3307    /// Returns `None` if this expression cannot be simplified, e.g. because it
3308    /// contains non-literal values.
3309    ///
3310    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3311    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3312    ///
3313    /// # Panics
3314    ///
3315    /// Panics if this expression does not have type [`ScalarType::MzTimestamp`].
3316    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3317        self.simplify_to_literal().and_then(|row| {
3318            let datum = row.unpack_first();
3319            if datum.is_null() {
3320                None
3321            } else {
3322                Some(datum.unwrap_mz_timestamp())
3323            }
3324        })
3325    }
3326}
3327
3328impl VisitChildren<Self> for HirScalarExpr {
3329    fn visit_children<F>(&self, mut f: F)
3330    where
3331        F: FnMut(&Self),
3332    {
3333        use HirScalarExpr::*;
3334        match self {
3335            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3336            CallUnary { expr, .. } => f(expr),
3337            CallBinary { expr1, expr2, .. } => {
3338                f(expr1);
3339                f(expr2);
3340            }
3341            CallVariadic { exprs, .. } => {
3342                for expr in exprs {
3343                    f(expr);
3344                }
3345            }
3346            If { cond, then, els } => {
3347                f(cond);
3348                f(then);
3349                f(els);
3350            }
3351            Exists(..) | Select(..) => (),
3352            Windowing(expr) => expr.visit_children(f),
3353        }
3354    }
3355
3356    fn visit_mut_children<F>(&mut self, mut f: F)
3357    where
3358        F: FnMut(&mut Self),
3359    {
3360        use HirScalarExpr::*;
3361        match self {
3362            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3363            CallUnary { expr, .. } => f(expr),
3364            CallBinary { expr1, expr2, .. } => {
3365                f(expr1);
3366                f(expr2);
3367            }
3368            CallVariadic { exprs, .. } => {
3369                for expr in exprs {
3370                    f(expr);
3371                }
3372            }
3373            If { cond, then, els } => {
3374                f(cond);
3375                f(then);
3376                f(els);
3377            }
3378            Exists(..) | Select(..) => (),
3379            Windowing(expr) => expr.visit_mut_children(f),
3380        }
3381    }
3382
3383    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3384    where
3385        F: FnMut(&Self) -> Result<(), E>,
3386        E: From<RecursionLimitError>,
3387    {
3388        use HirScalarExpr::*;
3389        match self {
3390            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3391            CallUnary { expr, .. } => f(expr)?,
3392            CallBinary { expr1, expr2, .. } => {
3393                f(expr1)?;
3394                f(expr2)?;
3395            }
3396            CallVariadic { exprs, .. } => {
3397                for expr in exprs {
3398                    f(expr)?;
3399                }
3400            }
3401            If { cond, then, els } => {
3402                f(cond)?;
3403                f(then)?;
3404                f(els)?;
3405            }
3406            Exists(..) | Select(..) => (),
3407            Windowing(expr) => expr.try_visit_children(f)?,
3408        }
3409        Ok(())
3410    }
3411
3412    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3413    where
3414        F: FnMut(&mut Self) -> Result<(), E>,
3415        E: From<RecursionLimitError>,
3416    {
3417        use HirScalarExpr::*;
3418        match self {
3419            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3420            CallUnary { expr, .. } => f(expr)?,
3421            CallBinary { expr1, expr2, .. } => {
3422                f(expr1)?;
3423                f(expr2)?;
3424            }
3425            CallVariadic { exprs, .. } => {
3426                for expr in exprs {
3427                    f(expr)?;
3428                }
3429            }
3430            If { cond, then, els } => {
3431                f(cond)?;
3432                f(then)?;
3433                f(els)?;
3434            }
3435            Exists(..) | Select(..) => (),
3436            Windowing(expr) => expr.try_visit_mut_children(f)?,
3437        }
3438        Ok(())
3439    }
3440}
3441
3442impl AbstractExpr for HirScalarExpr {
3443    type Type = ColumnType;
3444
3445    fn typ(
3446        &self,
3447        outers: &[RelationType],
3448        inner: &RelationType,
3449        params: &BTreeMap<usize, ScalarType>,
3450    ) -> Self::Type {
3451        stack::maybe_grow(|| match self {
3452            HirScalarExpr::Column(ColumnRef { level, column }) => {
3453                if *level == 0 {
3454                    inner.column_types[*column].clone()
3455                } else {
3456                    outers[*level - 1].column_types[*column].clone()
3457                }
3458            }
3459            HirScalarExpr::Parameter(n) => params[n].clone().nullable(true),
3460            HirScalarExpr::Literal(_, typ) => typ.clone(),
3461            HirScalarExpr::CallUnmaterializable(func) => func.output_type(),
3462            HirScalarExpr::CallUnary { expr, func } => {
3463                func.output_type(expr.typ(outers, inner, params))
3464            }
3465            HirScalarExpr::CallBinary { expr1, expr2, func } => func.output_type(
3466                expr1.typ(outers, inner, params),
3467                expr2.typ(outers, inner, params),
3468            ),
3469            HirScalarExpr::CallVariadic { exprs, func } => {
3470                func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect())
3471            }
3472            HirScalarExpr::If { cond: _, then, els } => {
3473                let then_type = then.typ(outers, inner, params);
3474                let else_type = els.typ(outers, inner, params);
3475                then_type.union(&else_type).unwrap()
3476            }
3477            HirScalarExpr::Exists(_) => ScalarType::Bool.nullable(true),
3478            HirScalarExpr::Select(expr) => {
3479                let mut outers = outers.to_vec();
3480                outers.insert(0, inner.clone());
3481                expr.typ(&outers, params)
3482                    .column_types
3483                    .into_element()
3484                    .nullable(true)
3485            }
3486            HirScalarExpr::Windowing(expr) => expr.func.typ(outers, inner, params),
3487        })
3488    }
3489}
3490
3491impl AggregateExpr {
3492    /// Replaces any parameter references in the expression with the
3493    /// corresponding datum from `parameters`.
3494    pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
3495        self.expr.bind_parameters(params)
3496    }
3497
3498    pub fn typ(
3499        &self,
3500        outers: &[RelationType],
3501        inner: &RelationType,
3502        params: &BTreeMap<usize, ScalarType>,
3503    ) -> ColumnType {
3504        self.func.output_type(self.expr.typ(outers, inner, params))
3505    }
3506
3507    /// Returns whether the expression is COUNT(*) or not.  Note that
3508    /// when we define the count builtin in sql::func, we convert
3509    /// COUNT(*) to COUNT(true), making it indistinguishable from
3510    /// literal COUNT(true), but we prefer to consider this as the
3511    /// former.
3512    ///
3513    /// (MIR has the same `is_count_asterisk`.)
3514    pub fn is_count_asterisk(&self) -> bool {
3515        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3516    }
3517}