mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Marker type selecting the HIR flavor of the generic [`IR`] machinery: its relation type is
/// [`HirRelationExpr`] and its scalar type is [`HirScalarExpr`].
///
/// The struct carries no data, so deriving `Debug` is free and replaces the previous blanket
/// `#[allow(missing_debug_implementations)]`.
#[derive(Debug)]
pub struct Hir;
48
49impl IR for Hir {
50    type Relation = HirRelationExpr;
51    type Scalar = HirScalarExpr;
52}
53
54impl AlgExcept for Hir {
55    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56        if *all {
57            let rhs = rhs.negate();
58            HirRelationExpr::union(lhs, rhs).threshold()
59        } else {
60            let lhs = lhs.distinct();
61            let rhs = rhs.distinct().negate();
62            HirRelationExpr::union(lhs, rhs).threshold()
63        }
64    }
65
66    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67        let mut result = None;
68
69        use HirRelationExpr::*;
70        if let Threshold { input } = expr {
71            if let Union { base: lhs, inputs } = input.as_ref() {
72                if let [rhs] = &inputs[..] {
73                    if let Negate { input: rhs } = rhs {
74                        match (lhs.as_ref(), rhs.as_ref()) {
75                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
76                                let all = false;
77                                let lhs = lhs.as_ref();
78                                let rhs = rhs.as_ref();
79                                result = Some(Except { all, lhs, rhs })
80                            }
81                            (lhs, rhs) => {
82                                let all = true;
83                                result = Some(Except { all, lhs, rhs })
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        result
92    }
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
96/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
97pub enum HirRelationExpr {
98    Constant {
99        rows: Vec<Row>,
100        typ: SqlRelationType,
101    },
102    Get {
103        id: mz_expr::Id,
104        typ: SqlRelationType,
105    },
106    /// Mutually recursive CTE
107    LetRec {
108        /// Maximum number of iterations to evaluate. If None, then there is no limit.
109        limit: Option<LetRecLimit>,
110        /// List of bindings all of which are in scope of each other.
111        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
112        /// Result of the AST node.
113        body: Box<HirRelationExpr>,
114    },
115    /// CTE
116    Let {
117        name: String,
118        /// The identifier to be used in `Get` variants to retrieve `value`.
119        id: mz_expr::LocalId,
120        /// The collection to be bound to `name`.
121        value: Box<HirRelationExpr>,
122        /// The result of the `Let`, evaluated with `name` bound to `value`.
123        body: Box<HirRelationExpr>,
124    },
125    Project {
126        input: Box<HirRelationExpr>,
127        outputs: Vec<usize>,
128    },
129    Map {
130        input: Box<HirRelationExpr>,
131        scalars: Vec<HirScalarExpr>,
132    },
133    CallTable {
134        func: TableFunc,
135        exprs: Vec<HirScalarExpr>,
136    },
137    Filter {
138        input: Box<HirRelationExpr>,
139        predicates: Vec<HirScalarExpr>,
140    },
141    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
142    /// joins away into more primitive exprs
143    Join {
144        left: Box<HirRelationExpr>,
145        right: Box<HirRelationExpr>,
146        on: HirScalarExpr,
147        kind: JoinKind,
148    },
149    /// Unlike MirRelationExpr, when `key` is empty AND `input` is empty this returns
150    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
151    /// rows
152    Reduce {
153        input: Box<HirRelationExpr>,
154        group_key: Vec<usize>,
155        aggregates: Vec<AggregateExpr>,
156        expected_group_size: Option<u64>,
157    },
158    Distinct {
159        input: Box<HirRelationExpr>,
160    },
161    /// Groups and orders within each group, limiting output.
162    TopK {
163        /// The source collection.
164        input: Box<HirRelationExpr>,
165        /// Column indices used to form groups.
166        group_key: Vec<usize>,
167        /// Column indices used to order rows within groups.
168        order_key: Vec<ColumnOrder>,
169        /// Number of records to retain.
170        /// It is of SqlScalarType::Int64.
171        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
172        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
173        /// is better for Postgres compat. This is because if there is a $1 here, then when external
174        /// tools `describe` the prepared statement, they discover this type. If what they find
175        /// were UInt64, then they might have trouble calling the prepared statement, because the
176        /// unsigned types are non-standard, and also don't exist even in Postgres.)
177        limit: Option<HirScalarExpr>,
178        /// Number of records to skip.
179        /// It is of SqlScalarType::Int64.
180        /// This can contain parameters at first, but by the time we reach lowering, this should
181        /// already be simply a Literal.
182        offset: HirScalarExpr,
183        /// User-supplied hint: how many rows will have the same group key.
184        expected_group_size: Option<u64>,
185    },
186    Negate {
187        input: Box<HirRelationExpr>,
188    },
189    /// Keep rows from a dataflow where the row counts are positive.
190    Threshold {
191        input: Box<HirRelationExpr>,
192    },
193    Union {
194        base: Box<HirRelationExpr>,
195        inputs: Vec<HirRelationExpr>,
196    },
197}
198
199/// Stored column metadata.
200pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
201
202#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
203/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
204pub enum HirScalarExpr {
205    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
206    /// variable could refer to a column of the current input, or to a column of an outer relation.
207    /// We use ColumnRef to denote the difference.
208    Column(ColumnRef, NameMetadata),
209    Parameter(usize, NameMetadata),
210    Literal(Row, SqlColumnType, NameMetadata),
211    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
212    CallUnary {
213        func: UnaryFunc,
214        expr: Box<HirScalarExpr>,
215        name: NameMetadata,
216    },
217    CallBinary {
218        func: BinaryFunc,
219        expr1: Box<HirScalarExpr>,
220        expr2: Box<HirScalarExpr>,
221        name: NameMetadata,
222    },
223    CallVariadic {
224        func: VariadicFunc,
225        exprs: Vec<HirScalarExpr>,
226        name: NameMetadata,
227    },
228    If {
229        cond: Box<HirScalarExpr>,
230        then: Box<HirScalarExpr>,
231        els: Box<HirScalarExpr>,
232        name: NameMetadata,
233    },
234    /// Returns true if `expr` returns any rows
235    Exists(Box<HirRelationExpr>, NameMetadata),
236    /// Given `expr` with arity 1. If expr returns:
237    /// * 0 rows, return NULL
238    /// * 1 row, return the value of that row
239    /// * >1 rows, we return an error
240    Select(Box<HirRelationExpr>, NameMetadata),
241    Windowing(WindowExpr, NameMetadata),
242}
243
244#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
245/// Represents the invocation of a window function over an optional partitioning with an optional
246/// order.
247pub struct WindowExpr {
248    pub func: WindowExprType,
249    pub partition_by: Vec<HirScalarExpr>,
250    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
251    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
252    ///    above,
253    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
254    /// These are separated because they are used in different places: the outer `order_by` is used
255    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
256    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
257    /// (`WindowExpr` exists only in HIR, but not in MIR.)
258    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
259    /// lowering, and not to original input columns.
260    pub order_by: Vec<HirScalarExpr>,
261}
262
263impl WindowExpr {
264    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265    where
266        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267    {
268        #[allow(deprecated)]
269        self.func.visit_expressions(f)?;
270        for expr in self.partition_by.iter() {
271            f(expr)?;
272        }
273        for expr in self.order_by.iter() {
274            f(expr)?;
275        }
276        Ok(())
277    }
278
279    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280    where
281        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282    {
283        #[allow(deprecated)]
284        self.func.visit_expressions_mut(f)?;
285        for expr in self.partition_by.iter_mut() {
286            f(expr)?;
287        }
288        for expr in self.order_by.iter_mut() {
289            f(expr)?;
290        }
291        Ok(())
292    }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296    fn visit_children<F>(&self, mut f: F)
297    where
298        F: FnMut(&HirScalarExpr),
299    {
300        self.func.visit_children(&mut f);
301        for expr in self.partition_by.iter() {
302            f(expr);
303        }
304        for expr in self.order_by.iter() {
305            f(expr);
306        }
307    }
308
309    fn visit_mut_children<F>(&mut self, mut f: F)
310    where
311        F: FnMut(&mut HirScalarExpr),
312    {
313        self.func.visit_mut_children(&mut f);
314        for expr in self.partition_by.iter_mut() {
315            f(expr);
316        }
317        for expr in self.order_by.iter_mut() {
318            f(expr);
319        }
320    }
321
322    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323    where
324        F: FnMut(&HirScalarExpr) -> Result<(), E>,
325        E: From<RecursionLimitError>,
326    {
327        self.func.try_visit_children(&mut f)?;
328        for expr in self.partition_by.iter() {
329            f(expr)?;
330        }
331        for expr in self.order_by.iter() {
332            f(expr)?;
333        }
334        Ok(())
335    }
336
337    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338    where
339        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340        E: From<RecursionLimitError>,
341    {
342        self.func.try_visit_mut_children(&mut f)?;
343        for expr in self.partition_by.iter_mut() {
344            f(expr)?;
345        }
346        for expr in self.order_by.iter_mut() {
347            f(expr)?;
348        }
349        Ok(())
350    }
351}
352
353#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
354/// A window function with its parameters.
355///
356/// There are three types of window functions:
357/// - scalar window functions, which return a different scalar value for each
358///   row within a partition that depends exclusively on the position of the row
359///   within the partition;
360/// - value window functions, which return a scalar value for each row within a
361///   partition that might be computed based on a single row, which is usually not
362///   the current row (e.g., previous or following row; first or last row of the
363///   partition);
364/// - aggregate window functions, which compute a traditional aggregation as a
365///   window function (e.g. `sum(x) OVER (...)`).
366///   (Aggregate window  functions can in some cases be computed by joining the
367///   input relation with a reduction over the same relation that computes the
368///   aggregation using the partition key as its grouping key, but we don't
369///   automatically do this currently.)
370pub enum WindowExprType {
371    Scalar(ScalarWindowExpr),
372    Value(ValueWindowExpr),
373    Aggregate(AggregateWindowExpr),
374}
375
376impl WindowExprType {
377    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379    where
380        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381    {
382        #[allow(deprecated)]
383        match self {
384            Self::Scalar(expr) => expr.visit_expressions(f),
385            Self::Value(expr) => expr.visit_expressions(f),
386            Self::Aggregate(expr) => expr.visit_expressions(f),
387        }
388    }
389
390    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392    where
393        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394    {
395        #[allow(deprecated)]
396        match self {
397            Self::Scalar(expr) => expr.visit_expressions_mut(f),
398            Self::Value(expr) => expr.visit_expressions_mut(f),
399            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400        }
401    }
402
403    fn typ(
404        &self,
405        outers: &[SqlRelationType],
406        inner: &SqlRelationType,
407        params: &BTreeMap<usize, SqlScalarType>,
408    ) -> SqlColumnType {
409        match self {
410            Self::Scalar(expr) => expr.typ(outers, inner, params),
411            Self::Value(expr) => expr.typ(outers, inner, params),
412            Self::Aggregate(expr) => expr.typ(outers, inner, params),
413        }
414    }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418    fn visit_children<F>(&self, f: F)
419    where
420        F: FnMut(&HirScalarExpr),
421    {
422        match self {
423            Self::Scalar(_) => (),
424            Self::Value(expr) => expr.visit_children(f),
425            Self::Aggregate(expr) => expr.visit_children(f),
426        }
427    }
428
429    fn visit_mut_children<F>(&mut self, f: F)
430    where
431        F: FnMut(&mut HirScalarExpr),
432    {
433        match self {
434            Self::Scalar(_) => (),
435            Self::Value(expr) => expr.visit_mut_children(f),
436            Self::Aggregate(expr) => expr.visit_mut_children(f),
437        }
438    }
439
440    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441    where
442        F: FnMut(&HirScalarExpr) -> Result<(), E>,
443        E: From<RecursionLimitError>,
444    {
445        match self {
446            Self::Scalar(_) => Ok(()),
447            Self::Value(expr) => expr.try_visit_children(f),
448            Self::Aggregate(expr) => expr.try_visit_children(f),
449        }
450    }
451
452    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453    where
454        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455        E: From<RecursionLimitError>,
456    {
457        match self {
458            Self::Scalar(_) => Ok(()),
459            Self::Value(expr) => expr.try_visit_mut_children(f),
460            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461        }
462    }
463}
464
465#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
466pub struct ScalarWindowExpr {
467    pub func: ScalarWindowFunc,
468    pub order_by: Vec<ColumnOrder>,
469}
470
471impl ScalarWindowExpr {
472    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474    where
475        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476    {
477        match self.func {
478            ScalarWindowFunc::RowNumber => {}
479            ScalarWindowFunc::Rank => {}
480            ScalarWindowFunc::DenseRank => {}
481        }
482        Ok(())
483    }
484
485    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487    where
488        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489    {
490        match self.func {
491            ScalarWindowFunc::RowNumber => {}
492            ScalarWindowFunc::Rank => {}
493            ScalarWindowFunc::DenseRank => {}
494        }
495        Ok(())
496    }
497
498    fn typ(
499        &self,
500        _outers: &[SqlRelationType],
501        _inner: &SqlRelationType,
502        _params: &BTreeMap<usize, SqlScalarType>,
503    ) -> SqlColumnType {
504        self.func.output_type()
505    }
506
507    pub fn into_expr(self) -> mz_expr::AggregateFunc {
508        match self.func {
509            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510                order_by: self.order_by,
511            },
512            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513                order_by: self.order_by,
514            },
515            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516                order_by: self.order_by,
517            },
518        }
519    }
520}
521
522#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
523/// Scalar Window functions
524pub enum ScalarWindowFunc {
525    RowNumber,
526    Rank,
527    DenseRank,
528}
529
530impl Display for ScalarWindowFunc {
531    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532        match self {
533            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534            ScalarWindowFunc::Rank => write!(f, "rank"),
535            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536        }
537    }
538}
539
540impl ScalarWindowFunc {
541    pub fn output_type(&self) -> SqlColumnType {
542        match self {
543            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
544            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
545            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
546        }
547    }
548}
549
550#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
551pub struct ValueWindowExpr {
552    pub func: ValueWindowFunc,
553    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
554    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
555    /// e.g., `row(#1, 3, null)`.
556    /// If it's a fused window function, then the arguments of each of the constituent function
557    /// calls are wrapped in an outer record.
558    pub args: Box<HirScalarExpr>,
559    /// See comment on `WindowExpr::order_by`.
560    pub order_by: Vec<ColumnOrder>,
561    pub window_frame: WindowFrame,
562    pub ignore_nulls: bool,
563}
564
565impl Display for ValueWindowFunc {
566    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567        match self {
568            ValueWindowFunc::Lag => write!(f, "lag"),
569            ValueWindowFunc::Lead => write!(f, "lead"),
570            ValueWindowFunc::FirstValue => write!(f, "first_value"),
571            ValueWindowFunc::LastValue => write!(f, "last_value"),
572            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573        }
574    }
575}
576
577impl ValueWindowExpr {
578    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580    where
581        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582    {
583        f(&self.args)
584    }
585
586    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588    where
589        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590    {
591        f(&mut self.args)
592    }
593
594    fn typ(
595        &self,
596        outers: &[SqlRelationType],
597        inner: &SqlRelationType,
598        params: &BTreeMap<usize, SqlScalarType>,
599    ) -> SqlColumnType {
600        self.func.output_type(self.args.typ(outers, inner, params))
601    }
602
603    /// Converts into `mz_expr::AggregateFunc`.
604    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605        (
606            self.args,
607            self.func
608                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609        )
610    }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614    fn visit_children<F>(&self, mut f: F)
615    where
616        F: FnMut(&HirScalarExpr),
617    {
618        f(&self.args)
619    }
620
621    fn visit_mut_children<F>(&mut self, mut f: F)
622    where
623        F: FnMut(&mut HirScalarExpr),
624    {
625        f(&mut self.args)
626    }
627
628    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629    where
630        F: FnMut(&HirScalarExpr) -> Result<(), E>,
631        E: From<RecursionLimitError>,
632    {
633        f(&self.args)
634    }
635
636    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637    where
638        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639        E: From<RecursionLimitError>,
640    {
641        f(&mut self.args)
642    }
643}
644
645#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
646/// Value Window functions
647pub enum ValueWindowFunc {
648    Lag,
649    Lead,
650    FirstValue,
651    LastValue,
652    Fused(Vec<ValueWindowFunc>),
653}
654
655impl ValueWindowFunc {
656    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
657        match self {
658            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659                // The input is a (value, offset, default) record, so extract the type of the first arg
660                input_type.scalar_type.unwrap_record_element_type()[0]
661                    .clone()
662                    .nullable(true)
663            }
664            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665                input_type.scalar_type.nullable(true)
666            }
667            ValueWindowFunc::Fused(funcs) => {
668                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669                SqlScalarType::Record {
670                    fields: funcs
671                        .iter()
672                        .zip_eq(input_types)
673                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674                        .collect(),
675                    custom_id: None,
676                }
677                .nullable(false)
678            }
679        }
680    }
681
682    pub fn into_expr(
683        self,
684        order_by: Vec<ColumnOrder>,
685        window_frame: WindowFrame,
686        ignore_nulls: bool,
687    ) -> mz_expr::AggregateFunc {
688        match self {
689            // Lag and Lead are fundamentally the same function, just with opposite directions
690            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691                order_by,
692                lag_lead: mz_expr::LagLeadType::Lag,
693                ignore_nulls,
694            },
695            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696                order_by,
697                lag_lead: mz_expr::LagLeadType::Lead,
698                ignore_nulls,
699            },
700            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701                order_by,
702                window_frame,
703            },
704            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705                order_by,
706                window_frame,
707            },
708            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709                funcs: funcs
710                    .into_iter()
711                    .map(|func| {
712                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713                    })
714                    .collect(),
715                order_by,
716            },
717        }
718    }
719}
720
721#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
722pub struct AggregateWindowExpr {
723    pub aggregate_expr: AggregateExpr,
724    pub order_by: Vec<ColumnOrder>,
725    pub window_frame: WindowFrame,
726}
727
728impl AggregateWindowExpr {
729    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731    where
732        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733    {
734        f(&self.aggregate_expr.expr)
735    }
736
737    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739    where
740        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741    {
742        f(&mut self.aggregate_expr.expr)
743    }
744
745    fn typ(
746        &self,
747        outers: &[SqlRelationType],
748        inner: &SqlRelationType,
749        params: &BTreeMap<usize, SqlScalarType>,
750    ) -> SqlColumnType {
751        self.aggregate_expr
752            .func
753            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754    }
755
756    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758            (
759                self.aggregate_expr.expr,
760                FusedWindowAggregate {
761                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762                    order_by: self.order_by,
763                    window_frame: self.window_frame,
764                },
765            )
766        } else {
767            (
768                self.aggregate_expr.expr,
769                WindowAggregate {
770                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771                    order_by: self.order_by,
772                    window_frame: self.window_frame,
773                },
774            )
775        }
776    }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780    fn visit_children<F>(&self, mut f: F)
781    where
782        F: FnMut(&HirScalarExpr),
783    {
784        f(&self.aggregate_expr.expr)
785    }
786
787    fn visit_mut_children<F>(&mut self, mut f: F)
788    where
789        F: FnMut(&mut HirScalarExpr),
790    {
791        f(&mut self.aggregate_expr.expr)
792    }
793
794    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795    where
796        F: FnMut(&HirScalarExpr) -> Result<(), E>,
797        E: From<RecursionLimitError>,
798    {
799        f(&self.aggregate_expr.expr)
800    }
801
802    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803    where
804        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805        E: From<RecursionLimitError>,
806    {
807        f(&mut self.aggregate_expr.expr)
808    }
809}
810
811/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
812/// determined. Several SQL expressions can be freely coerced based upon where
813/// in the expression tree they appear. For example, the string literal '42'
814/// will be automatically coerced to the integer 42 if used in a numeric
815/// context:
816///
817/// ```sql
818/// SELECT '42' + 42
819/// ```
820///
821/// This separate type gives the code that needs to interact with coercions very
822/// fine-grained control over what coercions happen and when.
823///
824/// The primary driver of coercion is function and operator selection, as
825/// choosing the correct function or operator implementation depends on the type
826/// of the provided arguments. Coercion also occurs at the very root of the
827/// scalar expression tree. For example in
828///
829/// ```sql
830/// SELECT ... WHERE $1
831/// ```
832///
833/// the `WHERE` clause will coerce the contained unconstrained type parameter
834/// `$1` to have type bool.
835#[derive(Clone, Debug)]
836pub enum CoercibleScalarExpr {
837    Coerced(HirScalarExpr),
838    Parameter(usize),
839    LiteralNull,
840    LiteralString(String),
841    LiteralRecord(Vec<CoercibleScalarExpr>),
842}
843
844impl CoercibleScalarExpr {
845    pub fn type_as(
846        self,
847        ecx: &ExprContext,
848        ty: &SqlScalarType,
849    ) -> Result<HirScalarExpr, PlanError> {
850        let expr = typeconv::plan_coerce(ecx, self, ty)?;
851        let expr_ty = ecx.scalar_type(&expr);
852        if ty != &expr_ty {
853            sql_bail!(
854                "{} must have type {}, not type {}",
855                ecx.name,
856                ecx.humanize_scalar_type(ty, false),
857                ecx.humanize_scalar_type(&expr_ty, false),
858            );
859        }
860        Ok(expr)
861    }
862
863    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
864        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
865    }
866
867    pub fn cast_to(
868        self,
869        ecx: &ExprContext,
870        ccx: CastContext,
871        ty: &SqlScalarType,
872    ) -> Result<HirScalarExpr, PlanError> {
873        let expr = typeconv::plan_coerce(ecx, self, ty)?;
874        typeconv::plan_cast(ecx, ccx, expr, ty)
875    }
876}
877
/// The column type for a [`CoercibleScalarExpr`].
///
/// This is the `Type` produced by `AbstractExpr::typ` for
/// `CoercibleScalarExpr`.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The column type of an expression that has already been coerced.
    Coerced(SqlColumnType),
    /// The per-field types of a literal record; each field may itself be
    /// uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// The type of a parameter, `NULL` literal, or string literal whose type
    /// has not been determined yet.
    Uncoerced,
}
885
886impl CoercibleColumnType {
887    /// Reports the nullability of the type.
888    pub fn nullable(&self) -> bool {
889        match self {
890            // A coerced value's nullability is known.
891            CoercibleColumnType::Coerced(ct) => ct.nullable,
892
893            // A literal record can never be null.
894            CoercibleColumnType::Record(_) => false,
895
896            // An uncoerced literal may be the literal `NULL`, so we have
897            // to conservatively assume it is nullable.
898            CoercibleColumnType::Uncoerced => true,
899        }
900    }
901}
902
/// The scalar type for a [`CoercibleScalarExpr`].
///
/// This mirrors [`CoercibleColumnType`] without the top-level nullability of
/// coerced types.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A fully determined scalar type.
    Coerced(SqlScalarType),
    /// The per-field column types of a literal record; fields may be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// A type that has not been determined yet (parameter or literal).
    Uncoerced,
}
910
911impl CoercibleScalarType {
912    /// Reports whether the scalar type has been coerced.
913    pub fn is_coerced(&self) -> bool {
914        matches!(self, CoercibleScalarType::Coerced(_))
915    }
916
917    /// Returns the coerced scalar type, if the type is coerced.
918    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
919        match self {
920            CoercibleScalarType::Coerced(t) => Some(t),
921            _ => None,
922        }
923    }
924
925    /// If the type is coerced, apply the mapping function to the contained
926    /// scalar type.
927    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
928    where
929        F: FnOnce(SqlScalarType) -> SqlScalarType,
930    {
931        match self {
932            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
933            _ => self,
934        }
935    }
936
937    /// If the type is an coercible record, forcibly converts to a coerced
938    /// record type. Any uncoerced field types are assumed to be of type text.
939    ///
940    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
941    /// accepts a type hint that can indicate the types of uncoerced field
942    /// types.
943    pub fn force_coerced_if_record(&mut self) {
944        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
945            let mut fields = vec![];
946            for (i, uf) in uncoerced_fields.enumerate() {
947                let name = ColumnName::from(format!("f{}", i + 1));
948                let ty = match uf {
949                    CoercibleColumnType::Coerced(ty) => ty,
950                    CoercibleColumnType::Record(mut fields) => {
951                        convert(fields.drain(..)).nullable(false)
952                    }
953                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
954                };
955                fields.push((name, ty))
956            }
957            SqlScalarType::Record {
958                fields: fields.into(),
959                custom_id: None,
960            }
961        }
962
963        if let CoercibleScalarType::Record(fields) = self {
964            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
965        }
966    }
967}
968
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object returned by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` are the types of the enclosing scopes, and `inner` is the type
    /// of the relation the expression is evaluated against. `params` provides
    /// parameter types (presumably keyed by parameter number — confirm).
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
983
984impl AbstractExpr for CoercibleScalarExpr {
985    type Type = CoercibleColumnType;
986
987    fn typ(
988        &self,
989        outers: &[SqlRelationType],
990        inner: &SqlRelationType,
991        params: &BTreeMap<usize, SqlScalarType>,
992    ) -> Self::Type {
993        match self {
994            CoercibleScalarExpr::Coerced(expr) => {
995                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
996            }
997            CoercibleScalarExpr::LiteralRecord(scalars) => {
998                let fields = scalars
999                    .iter()
1000                    .map(|s| s.typ(outers, inner, params))
1001                    .collect();
1002                CoercibleColumnType::Record(fields)
1003            }
1004            _ => CoercibleColumnType::Uncoerced,
1005        }
1006    }
1007}
1008
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object returned by
    /// [`AbstractColumnType::scalar_type`].
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object. For the implementations in this file, this drops the top-level
    /// nullability.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1020
1021impl AbstractColumnType for SqlColumnType {
1022    type AbstractScalarType = SqlScalarType;
1023
1024    fn scalar_type(self) -> Self::AbstractScalarType {
1025        self.scalar_type
1026    }
1027}
1028
1029impl AbstractColumnType for CoercibleColumnType {
1030    type AbstractScalarType = CoercibleScalarType;
1031
1032    fn scalar_type(self) -> Self::AbstractScalarType {
1033        match self {
1034            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1035            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1036            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1037        }
1038    }
1039}
1040
1041impl From<HirScalarExpr> for CoercibleScalarExpr {
1042    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1043        CoercibleScalarExpr::Coerced(expr)
1044    }
1045}
1046
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier.
    pub column: usize,
}
1068
/// The kind of a `HirRelationExpr::Join`.
///
/// When typing a join, outer joins mark the columns of the side(s) that may be
/// padded as nullable.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Inner join; no column nullability is added.
    Inner,
    /// Left outer join; the right input's columns become nullable.
    LeftOuter,
    /// Right outer join; the left input's columns become nullable.
    RightOuter,
    /// Full outer join; columns of both inputs become nullable.
    FullOuter,
}
1076
1077impl fmt::Display for JoinKind {
1078    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1079        write!(
1080            f,
1081            "{}",
1082            match self {
1083                JoinKind::Inner => "Inner",
1084                JoinKind::LeftOuter => "LeftOuter",
1085                JoinKind::RightOuter => "RightOuter",
1086                JoinKind::FullOuter => "FullOuter",
1087            }
1088        )
1089    }
1090}
1091
1092impl JoinKind {
1093    pub fn can_be_correlated(&self) -> bool {
1094        match self {
1095            JoinKind::Inner | JoinKind::LeftOuter => true,
1096            JoinKind::RightOuter | JoinKind::FullOuter => false,
1097        }
1098    }
1099}
1100
/// A single aggregation, as listed in `HirRelationExpr::Reduce`'s
/// `aggregates`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are aggregated.
    pub expr: Box<HirScalarExpr>,
    /// NOTE(review): presumably corresponds to SQL `agg(DISTINCT ...)`, i.e.
    /// duplicates are removed before aggregating — confirm against lowering.
    pub distinct: bool,
}
1107
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, the nullability of the aggregate columns is more common
/// here than in `expr`, as these aggregates may be applied over empty
/// result sets and should be null in those cases, whereas `expr` variants
/// only return null values when supplied nulls as input.
///
/// Apart from `FusedWindowAgg`, each variant is lowered to the `mz_expr`
/// variant of the same name by `AggregateFunc::into_expr`.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // Integer sums widen their output type (see `output_type`):
    // `SumInt16`/`SumInt32` produce `Int64`, `SumUInt16`/`SumUInt32` produce
    // `UInt64`, and `SumInt64`/`SumUInt64` produce a scale-0 `Numeric`. The
    // float and numeric sums keep their input type.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// Produces an `Int64`. This is the only aggregate whose output column is
    /// never nullable.
    Count,
    /// Boolean aggregate whose identity datum is `false`.
    Any,
    /// Boolean aggregate whose identity datum is `true`.
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Produces a `String`-typed result. `order_by` lists ordering columns, as
    /// for the other order-sensitive aggregates above.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, whose each
    /// component will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1211
1212impl AggregateFunc {
1213    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1214    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1215        match self {
1216            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1217            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1218            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1219            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1220            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1221            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1222            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1223            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1224            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1225            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1226            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1227            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1228            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1229            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1230            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1231            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1232            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1233            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1234            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1235            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1236            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1237            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1238            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1239            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1240            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1241            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1242            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1243            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1244            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1245            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1246            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1247            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1248            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1249            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1250            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1251            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1252            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1253            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1254            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1255            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1256            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1257            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1258            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1259            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1260            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1261            AggregateFunc::All => mz_expr::AggregateFunc::All,
1262            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1263            AggregateFunc::JsonbObjectAgg { order_by } => {
1264                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1265            }
1266            AggregateFunc::MapAgg {
1267                order_by,
1268                value_type,
1269            } => mz_expr::AggregateFunc::MapAgg {
1270                order_by,
1271                value_type,
1272            },
1273            AggregateFunc::ArrayConcat { order_by } => {
1274                mz_expr::AggregateFunc::ArrayConcat { order_by }
1275            }
1276            AggregateFunc::ListConcat { order_by } => {
1277                mz_expr::AggregateFunc::ListConcat { order_by }
1278            }
1279            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1280            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1281            // `AggregateWindowExpr::into_expr`.
1282            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1283                panic!("into_expr called on FusedWindowAgg")
1284            }
1285            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1286        }
1287    }
1288
1289    /// Returns a datum whose inclusion in the aggregation will not change its
1290    /// result.
1291    ///
1292    /// # Panics
1293    ///
1294    /// Panics if called on a `FusedWindowAgg`.
1295    pub fn identity_datum(&self) -> Datum<'static> {
1296        match self {
1297            AggregateFunc::Any => Datum::False,
1298            AggregateFunc::All => Datum::True,
1299            AggregateFunc::Dummy => Datum::Dummy,
1300            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1301            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1302            AggregateFunc::MaxNumeric
1303            | AggregateFunc::MaxInt16
1304            | AggregateFunc::MaxInt32
1305            | AggregateFunc::MaxInt64
1306            | AggregateFunc::MaxUInt16
1307            | AggregateFunc::MaxUInt32
1308            | AggregateFunc::MaxUInt64
1309            | AggregateFunc::MaxMzTimestamp
1310            | AggregateFunc::MaxFloat32
1311            | AggregateFunc::MaxFloat64
1312            | AggregateFunc::MaxBool
1313            | AggregateFunc::MaxString
1314            | AggregateFunc::MaxDate
1315            | AggregateFunc::MaxTimestamp
1316            | AggregateFunc::MaxTimestampTz
1317            | AggregateFunc::MaxInterval
1318            | AggregateFunc::MaxTime
1319            | AggregateFunc::MinNumeric
1320            | AggregateFunc::MinInt16
1321            | AggregateFunc::MinInt32
1322            | AggregateFunc::MinInt64
1323            | AggregateFunc::MinUInt16
1324            | AggregateFunc::MinUInt32
1325            | AggregateFunc::MinUInt64
1326            | AggregateFunc::MinMzTimestamp
1327            | AggregateFunc::MinFloat32
1328            | AggregateFunc::MinFloat64
1329            | AggregateFunc::MinBool
1330            | AggregateFunc::MinString
1331            | AggregateFunc::MinDate
1332            | AggregateFunc::MinTimestamp
1333            | AggregateFunc::MinTimestampTz
1334            | AggregateFunc::MinInterval
1335            | AggregateFunc::MinTime
1336            | AggregateFunc::SumInt16
1337            | AggregateFunc::SumInt32
1338            | AggregateFunc::SumInt64
1339            | AggregateFunc::SumUInt16
1340            | AggregateFunc::SumUInt32
1341            | AggregateFunc::SumUInt64
1342            | AggregateFunc::SumFloat32
1343            | AggregateFunc::SumFloat64
1344            | AggregateFunc::SumNumeric
1345            | AggregateFunc::Count
1346            | AggregateFunc::JsonbAgg { .. }
1347            | AggregateFunc::JsonbObjectAgg { .. }
1348            | AggregateFunc::MapAgg { .. }
1349            | AggregateFunc::StringAgg { .. } => Datum::Null,
1350            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1351                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1352                // in HIR planning, because it is introduced only during HIR transformation.
1353                //
1354                // The implementation could be something like the following, except that we need to
1355                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1356                // ```
1357                // let temp_storage = RowArena::new();
1358                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1359                // ```
1360                panic!("FusedWindowAgg doesn't have an identity_datum")
1361            }
1362        }
1363    }
1364
1365    /// The output column type for the result of an aggregation.
1366    ///
1367    /// The output column type also contains nullability information, which
1368    /// is (without further information) true for aggregations that are not
1369    /// counts.
1370    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1371        let scalar_type = match self {
1372            AggregateFunc::Count => SqlScalarType::Int64,
1373            AggregateFunc::Any => SqlScalarType::Bool,
1374            AggregateFunc::All => SqlScalarType::Bool,
1375            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1376            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1377            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1378            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1379            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1380                max_scale: Some(NumericMaxScale::ZERO),
1381            },
1382            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1383            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1384                max_scale: Some(NumericMaxScale::ZERO),
1385            },
1386            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1387                value_type: Box::new(value_type.clone()),
1388                custom_id: None,
1389            },
1390            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1391                match input_type.scalar_type {
1392                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1393                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1394                    _ => unreachable!(),
1395                }
1396            }
1397            AggregateFunc::MaxNumeric
1398            | AggregateFunc::MaxInt16
1399            | AggregateFunc::MaxInt32
1400            | AggregateFunc::MaxInt64
1401            | AggregateFunc::MaxUInt16
1402            | AggregateFunc::MaxUInt32
1403            | AggregateFunc::MaxUInt64
1404            | AggregateFunc::MaxMzTimestamp
1405            | AggregateFunc::MaxFloat32
1406            | AggregateFunc::MaxFloat64
1407            | AggregateFunc::MaxBool
1408            | AggregateFunc::MaxString
1409            | AggregateFunc::MaxDate
1410            | AggregateFunc::MaxTimestamp
1411            | AggregateFunc::MaxTimestampTz
1412            | AggregateFunc::MaxInterval
1413            | AggregateFunc::MaxTime
1414            | AggregateFunc::MinNumeric
1415            | AggregateFunc::MinInt16
1416            | AggregateFunc::MinInt32
1417            | AggregateFunc::MinInt64
1418            | AggregateFunc::MinUInt16
1419            | AggregateFunc::MinUInt32
1420            | AggregateFunc::MinUInt64
1421            | AggregateFunc::MinMzTimestamp
1422            | AggregateFunc::MinFloat32
1423            | AggregateFunc::MinFloat64
1424            | AggregateFunc::MinBool
1425            | AggregateFunc::MinString
1426            | AggregateFunc::MinDate
1427            | AggregateFunc::MinTimestamp
1428            | AggregateFunc::MinTimestampTz
1429            | AggregateFunc::MinInterval
1430            | AggregateFunc::MinTime
1431            | AggregateFunc::SumFloat32
1432            | AggregateFunc::SumFloat64
1433            | AggregateFunc::SumNumeric
1434            | AggregateFunc::Dummy => input_type.scalar_type,
1435            AggregateFunc::FusedWindowAgg { funcs } => {
1436                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1437                SqlScalarType::Record {
1438                    fields: funcs
1439                        .iter()
1440                        .zip_eq(input_types)
1441                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1442                        .collect(),
1443                    custom_id: None,
1444                }
1445            }
1446        };
1447        // max/min/sum return null on empty sets
1448        let nullable = !matches!(self, AggregateFunc::Count);
1449        scalar_type.nullable(nullable)
1450    }
1451
1452    pub fn is_order_sensitive(&self) -> bool {
1453        use AggregateFunc::*;
1454        matches!(
1455            self,
1456            JsonbAgg { .. }
1457                | JsonbObjectAgg { .. }
1458                | MapAgg { .. }
1459                | ArrayConcat { .. }
1460                | ListConcat { .. }
1461                | StringAgg { .. }
1462        )
1463    }
1464}
1465
1466impl HirRelationExpr {
1467    pub fn typ(
1468        &self,
1469        outers: &[SqlRelationType],
1470        params: &BTreeMap<usize, SqlScalarType>,
1471    ) -> SqlRelationType {
1472        stack::maybe_grow(|| match self {
1473            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1474            HirRelationExpr::Get { typ, .. } => typ.clone(),
1475            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1476            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1477            HirRelationExpr::Project { input, outputs } => {
1478                let input_typ = input.typ(outers, params);
1479                SqlRelationType::new(
1480                    outputs
1481                        .iter()
1482                        .map(|&i| input_typ.column_types[i].clone())
1483                        .collect(),
1484                )
1485            }
1486            HirRelationExpr::Map { input, scalars } => {
1487                let mut typ = input.typ(outers, params);
1488                for scalar in scalars {
1489                    typ.column_types.push(scalar.typ(outers, &typ, params));
1490                }
1491                typ
1492            }
1493            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1494            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1495                input.typ(outers, params)
1496            }
1497            HirRelationExpr::Join {
1498                left, right, kind, ..
1499            } => {
1500                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1501                let right_nullable =
1502                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1503                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1504                    let nullable = t.nullable || left_nullable;
1505                    t.nullable(nullable)
1506                });
1507                let mut outers = outers.to_vec();
1508                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1509                let rt = right
1510                    .typ(&outers, params)
1511                    .column_types
1512                    .into_iter()
1513                    .map(|t| {
1514                        let nullable = t.nullable || right_nullable;
1515                        t.nullable(nullable)
1516                    });
1517                SqlRelationType::new(lt.chain(rt).collect())
1518            }
1519            HirRelationExpr::Reduce {
1520                input,
1521                group_key,
1522                aggregates,
1523                expected_group_size: _,
1524            } => {
1525                let input_typ = input.typ(outers, params);
1526                let mut column_types = group_key
1527                    .iter()
1528                    .map(|&i| input_typ.column_types[i].clone())
1529                    .collect::<Vec<_>>();
1530                for agg in aggregates {
1531                    column_types.push(agg.typ(outers, &input_typ, params));
1532                }
1533                // TODO(frank): add primary key information.
1534                SqlRelationType::new(column_types)
1535            }
1536            // TODO(frank): check for removal; add primary key information.
1537            HirRelationExpr::Distinct { input }
1538            | HirRelationExpr::Negate { input }
1539            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1540            HirRelationExpr::Union { base, inputs } => {
1541                let mut base_cols = base.typ(outers, params).column_types;
1542                for input in inputs {
1543                    for (base_col, col) in base_cols
1544                        .iter_mut()
1545                        .zip_eq(input.typ(outers, params).column_types)
1546                    {
1547                        *base_col = base_col.union(&col).unwrap();
1548                    }
1549                }
1550                SqlRelationType::new(base_cols)
1551            }
1552        })
1553    }
1554
1555    pub fn arity(&self) -> usize {
1556        match self {
1557            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1558            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1559            HirRelationExpr::Let { body, .. } => body.arity(),
1560            HirRelationExpr::LetRec { body, .. } => body.arity(),
1561            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1562            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1563            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1564            HirRelationExpr::Filter { input, .. }
1565            | HirRelationExpr::TopK { input, .. }
1566            | HirRelationExpr::Distinct { input }
1567            | HirRelationExpr::Negate { input }
1568            | HirRelationExpr::Threshold { input } => input.arity(),
1569            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1570            HirRelationExpr::Union { base, .. } => base.arity(),
1571            HirRelationExpr::Reduce {
1572                group_key,
1573                aggregates,
1574                ..
1575            } => group_key.len() + aggregates.len(),
1576        }
1577    }
1578
1579    /// If self is a constant, return the value and the type, otherwise `None`.
1580    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1581        match self {
1582            Self::Constant { rows, typ } => Some((rows, typ)),
1583            _ => None,
1584        }
1585    }
1586
1587    /// Reports whether this expression contains a column reference to its
1588    /// direct parent scope.
1589    pub fn is_correlated(&self) -> bool {
1590        let mut correlated = false;
1591        #[allow(deprecated)]
1592        self.visit_columns(0, &mut |depth, col| {
1593            if col.level > depth && col.level - depth == 1 {
1594                correlated = true;
1595            }
1596        });
1597        correlated
1598    }
1599
1600    pub fn is_join_identity(&self) -> bool {
1601        match self {
1602            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1603            _ => false,
1604        }
1605    }
1606
1607    pub fn project(self, outputs: Vec<usize>) -> Self {
1608        if outputs.iter().copied().eq(0..self.arity()) {
1609            // The projection is trivial. Suppress it.
1610            self
1611        } else {
1612            HirRelationExpr::Project {
1613                input: Box::new(self),
1614                outputs,
1615            }
1616        }
1617    }
1618
1619    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1620        if scalars.is_empty() {
1621            // The map is trivial. Suppress it.
1622            self
1623        } else if let HirRelationExpr::Map {
1624            scalars: old_scalars,
1625            input: _,
1626        } = &mut self
1627        {
1628            // Map applied to a map. Fuse the maps.
1629            old_scalars.extend(scalars);
1630            self
1631        } else {
1632            HirRelationExpr::Map {
1633                input: Box::new(self),
1634                scalars,
1635            }
1636        }
1637    }
1638
1639    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1640        if let HirRelationExpr::Filter {
1641            input: _,
1642            predicates,
1643        } = &mut self
1644        {
1645            predicates.extend(preds);
1646            predicates.sort();
1647            predicates.dedup();
1648            self
1649        } else {
1650            preds.sort();
1651            preds.dedup();
1652            HirRelationExpr::Filter {
1653                input: Box::new(self),
1654                predicates: preds,
1655            }
1656        }
1657    }
1658
1659    pub fn reduce(
1660        self,
1661        group_key: Vec<usize>,
1662        aggregates: Vec<AggregateExpr>,
1663        expected_group_size: Option<u64>,
1664    ) -> Self {
1665        HirRelationExpr::Reduce {
1666            input: Box::new(self),
1667            group_key,
1668            aggregates,
1669            expected_group_size,
1670        }
1671    }
1672
1673    pub fn top_k(
1674        self,
1675        group_key: Vec<usize>,
1676        order_key: Vec<ColumnOrder>,
1677        limit: Option<HirScalarExpr>,
1678        offset: HirScalarExpr,
1679        expected_group_size: Option<u64>,
1680    ) -> Self {
1681        HirRelationExpr::TopK {
1682            input: Box::new(self),
1683            group_key,
1684            order_key,
1685            limit,
1686            offset,
1687            expected_group_size,
1688        }
1689    }
1690
1691    pub fn negate(self) -> Self {
1692        if let HirRelationExpr::Negate { input } = self {
1693            *input
1694        } else {
1695            HirRelationExpr::Negate {
1696                input: Box::new(self),
1697            }
1698        }
1699    }
1700
1701    pub fn distinct(self) -> Self {
1702        if let HirRelationExpr::Distinct { .. } = self {
1703            self
1704        } else {
1705            HirRelationExpr::Distinct {
1706                input: Box::new(self),
1707            }
1708        }
1709    }
1710
1711    pub fn threshold(self) -> Self {
1712        if let HirRelationExpr::Threshold { .. } = self {
1713            self
1714        } else {
1715            HirRelationExpr::Threshold {
1716                input: Box::new(self),
1717            }
1718        }
1719    }
1720
1721    pub fn union(self, other: Self) -> Self {
1722        let mut terms = Vec::new();
1723        if let HirRelationExpr::Union { base, inputs } = self {
1724            terms.push(*base);
1725            terms.extend(inputs);
1726        } else {
1727            terms.push(self);
1728        }
1729        if let HirRelationExpr::Union { base, inputs } = other {
1730            terms.push(*base);
1731            terms.extend(inputs);
1732        } else {
1733            terms.push(other);
1734        }
1735        HirRelationExpr::Union {
1736            base: Box::new(terms.remove(0)),
1737            inputs: terms,
1738        }
1739    }
1740
1741    pub fn exists(self) -> HirScalarExpr {
1742        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1743    }
1744
1745    pub fn select(self) -> HirScalarExpr {
1746        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1747    }
1748
1749    pub fn join(
1750        self,
1751        mut right: HirRelationExpr,
1752        on: HirScalarExpr,
1753        kind: JoinKind,
1754    ) -> HirRelationExpr {
1755        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1756        {
1757            // The join can be elided, but we need to adjust column references
1758            // on the right-hand side to account for the removal of the scope
1759            // introduced by the join.
1760            #[allow(deprecated)]
1761            right.visit_columns_mut(0, &mut |depth, col| {
1762                if col.level > depth {
1763                    col.level -= 1;
1764                }
1765            });
1766            right
1767        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1768            self
1769        } else {
1770            HirRelationExpr::Join {
1771                left: Box::new(self),
1772                right: Box::new(right),
1773                on,
1774                kind,
1775            }
1776        }
1777    }
1778
1779    pub fn take(&mut self) -> HirRelationExpr {
1780        mem::replace(
1781            self,
1782            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1783        )
1784    }
1785
1786    #[deprecated = "Use `Visit::visit_post`."]
1787    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1788    where
1789        F: FnMut(&'a Self, usize),
1790    {
1791        #[allow(deprecated)]
1792        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1793                                                 depth: usize|
1794         -> Result<(), ()> {
1795            f(e, depth);
1796            Ok(())
1797        });
1798    }
1799
1800    #[deprecated = "Use `Visit::try_visit_post`."]
1801    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1802    where
1803        F: FnMut(&'a Self, usize) -> Result<(), E>,
1804    {
1805        #[allow(deprecated)]
1806        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1807            e.visit_fallible(depth, f)
1808        })?;
1809        f(self, depth)
1810    }
1811
1812    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1813    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1814    where
1815        F: FnMut(&'a Self, usize) -> Result<(), E>,
1816    {
1817        match self {
1818            HirRelationExpr::Constant { .. }
1819            | HirRelationExpr::Get { .. }
1820            | HirRelationExpr::CallTable { .. } => (),
1821            HirRelationExpr::Let { body, value, .. } => {
1822                f(value, depth)?;
1823                f(body, depth)?;
1824            }
1825            HirRelationExpr::LetRec {
1826                limit: _,
1827                bindings,
1828                body,
1829            } => {
1830                for (_, _, value, _) in bindings.iter() {
1831                    f(value, depth)?;
1832                }
1833                f(body, depth)?;
1834            }
1835            HirRelationExpr::Project { input, .. } => {
1836                f(input, depth)?;
1837            }
1838            HirRelationExpr::Map { input, .. } => {
1839                f(input, depth)?;
1840            }
1841            HirRelationExpr::Filter { input, .. } => {
1842                f(input, depth)?;
1843            }
1844            HirRelationExpr::Join { left, right, .. } => {
1845                f(left, depth)?;
1846                f(right, depth + 1)?;
1847            }
1848            HirRelationExpr::Reduce { input, .. } => {
1849                f(input, depth)?;
1850            }
1851            HirRelationExpr::Distinct { input } => {
1852                f(input, depth)?;
1853            }
1854            HirRelationExpr::TopK { input, .. } => {
1855                f(input, depth)?;
1856            }
1857            HirRelationExpr::Negate { input } => {
1858                f(input, depth)?;
1859            }
1860            HirRelationExpr::Threshold { input } => {
1861                f(input, depth)?;
1862            }
1863            HirRelationExpr::Union { base, inputs } => {
1864                f(base, depth)?;
1865                for input in inputs {
1866                    f(input, depth)?;
1867                }
1868            }
1869        }
1870        Ok(())
1871    }
1872
1873    #[deprecated = "Use `Visit::visit_mut_post` instead."]
1874    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1875    where
1876        F: FnMut(&mut Self, usize),
1877    {
1878        #[allow(deprecated)]
1879        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1880                                                     depth: usize|
1881         -> Result<(), ()> {
1882            f(e, depth);
1883            Ok(())
1884        });
1885    }
1886
1887    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1888    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1889    where
1890        F: FnMut(&mut Self, usize) -> Result<(), E>,
1891    {
1892        #[allow(deprecated)]
1893        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1894            e.visit_mut_fallible(depth, f)
1895        })?;
1896        f(self, depth)
1897    }
1898
1899    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1900    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1901    where
1902        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1903    {
1904        match self {
1905            HirRelationExpr::Constant { .. }
1906            | HirRelationExpr::Get { .. }
1907            | HirRelationExpr::CallTable { .. } => (),
1908            HirRelationExpr::Let { body, value, .. } => {
1909                f(value, depth)?;
1910                f(body, depth)?;
1911            }
1912            HirRelationExpr::LetRec {
1913                limit: _,
1914                bindings,
1915                body,
1916            } => {
1917                for (_, _, value, _) in bindings.iter_mut() {
1918                    f(value, depth)?;
1919                }
1920                f(body, depth)?;
1921            }
1922            HirRelationExpr::Project { input, .. } => {
1923                f(input, depth)?;
1924            }
1925            HirRelationExpr::Map { input, .. } => {
1926                f(input, depth)?;
1927            }
1928            HirRelationExpr::Filter { input, .. } => {
1929                f(input, depth)?;
1930            }
1931            HirRelationExpr::Join { left, right, .. } => {
1932                f(left, depth)?;
1933                f(right, depth + 1)?;
1934            }
1935            HirRelationExpr::Reduce { input, .. } => {
1936                f(input, depth)?;
1937            }
1938            HirRelationExpr::Distinct { input } => {
1939                f(input, depth)?;
1940            }
1941            HirRelationExpr::TopK { input, .. } => {
1942                f(input, depth)?;
1943            }
1944            HirRelationExpr::Negate { input } => {
1945                f(input, depth)?;
1946            }
1947            HirRelationExpr::Threshold { input } => {
1948                f(input, depth)?;
1949            }
1950            HirRelationExpr::Union { base, inputs } => {
1951                f(base, depth)?;
1952                for input in inputs {
1953                    f(input, depth)?;
1954                }
1955            }
1956        }
1957        Ok(())
1958    }
1959
1960    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1961    /// Visits all scalar expressions within the sub-tree of the given relation.
1962    ///
1963    /// The `depth` argument should indicate the subquery nesting depth of the expression,
1964    /// which will be incremented when entering the RHS of a join or a subquery and
1965    /// presented to the supplied function `f`.
1966    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1967    where
1968        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1969    {
1970        #[allow(deprecated)]
1971        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1972                                         depth: usize|
1973         -> Result<(), E> {
1974            match e {
1975                HirRelationExpr::Join { on, .. } => {
1976                    f(on, depth)?;
1977                }
1978                HirRelationExpr::Map { scalars, .. } => {
1979                    for scalar in scalars {
1980                        f(scalar, depth)?;
1981                    }
1982                }
1983                HirRelationExpr::CallTable { exprs, .. } => {
1984                    for expr in exprs {
1985                        f(expr, depth)?;
1986                    }
1987                }
1988                HirRelationExpr::Filter { predicates, .. } => {
1989                    for predicate in predicates {
1990                        f(predicate, depth)?;
1991                    }
1992                }
1993                HirRelationExpr::Reduce { aggregates, .. } => {
1994                    for aggregate in aggregates {
1995                        f(&aggregate.expr, depth)?;
1996                    }
1997                }
1998                HirRelationExpr::TopK { limit, offset, .. } => {
1999                    if let Some(limit) = limit {
2000                        f(limit, depth)?;
2001                    }
2002                    f(offset, depth)?;
2003                }
2004                HirRelationExpr::Union { .. }
2005                | HirRelationExpr::Let { .. }
2006                | HirRelationExpr::LetRec { .. }
2007                | HirRelationExpr::Project { .. }
2008                | HirRelationExpr::Distinct { .. }
2009                | HirRelationExpr::Negate { .. }
2010                | HirRelationExpr::Threshold { .. }
2011                | HirRelationExpr::Constant { .. }
2012                | HirRelationExpr::Get { .. } => (),
2013            }
2014            Ok(())
2015        })
2016    }
2017
2018    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2019    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2020    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2021    where
2022        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2023    {
2024        #[allow(deprecated)]
2025        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2026                                             depth: usize|
2027         -> Result<(), E> {
2028            match e {
2029                HirRelationExpr::Join { on, .. } => {
2030                    f(on, depth)?;
2031                }
2032                HirRelationExpr::Map { scalars, .. } => {
2033                    for scalar in scalars.iter_mut() {
2034                        f(scalar, depth)?;
2035                    }
2036                }
2037                HirRelationExpr::CallTable { exprs, .. } => {
2038                    for expr in exprs.iter_mut() {
2039                        f(expr, depth)?;
2040                    }
2041                }
2042                HirRelationExpr::Filter { predicates, .. } => {
2043                    for predicate in predicates.iter_mut() {
2044                        f(predicate, depth)?;
2045                    }
2046                }
2047                HirRelationExpr::Reduce { aggregates, .. } => {
2048                    for aggregate in aggregates.iter_mut() {
2049                        f(&mut aggregate.expr, depth)?;
2050                    }
2051                }
2052                HirRelationExpr::TopK { limit, offset, .. } => {
2053                    if let Some(limit) = limit {
2054                        f(limit, depth)?;
2055                    }
2056                    f(offset, depth)?;
2057                }
2058                HirRelationExpr::Union { .. }
2059                | HirRelationExpr::Let { .. }
2060                | HirRelationExpr::LetRec { .. }
2061                | HirRelationExpr::Project { .. }
2062                | HirRelationExpr::Distinct { .. }
2063                | HirRelationExpr::Negate { .. }
2064                | HirRelationExpr::Threshold { .. }
2065                | HirRelationExpr::Constant { .. }
2066                | HirRelationExpr::Get { .. } => (),
2067            }
2068            Ok(())
2069        })
2070    }
2071
2072    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2073    /// Visits the column references in this relation expression.
2074    ///
2075    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2076    /// which will be incremented when entering the RHS of a join or a subquery and
2077    /// presented to the supplied function `f`.
2078    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2079    where
2080        F: FnMut(usize, &ColumnRef),
2081    {
2082        #[allow(deprecated)]
2083        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2084                                                           depth: usize|
2085         -> Result<(), ()> {
2086            e.visit_columns(depth, f);
2087            Ok(())
2088        });
2089    }
2090
2091    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2092    /// Like `visit_columns`, but permits mutating the column references.
2093    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2094    where
2095        F: FnMut(usize, &mut ColumnRef),
2096    {
2097        #[allow(deprecated)]
2098        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2099                                                               depth: usize|
2100         -> Result<(), ()> {
2101            e.visit_columns_mut(depth, f);
2102            Ok(())
2103        });
2104    }
2105
2106    /// Replaces any parameter references in the expression with the
2107    /// corresponding datum from `params`.
2108    pub fn bind_parameters(
2109        &mut self,
2110        scx: &StatementContext,
2111        lifetime: QueryLifetime,
2112        params: &Params,
2113    ) -> Result<(), PlanError> {
2114        #[allow(deprecated)]
2115        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2116            e.bind_parameters(scx, lifetime, params)
2117        })
2118    }
2119
2120    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2121        let mut contains_parameters = false;
2122        #[allow(deprecated)]
2123        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2124            if e.contains_parameters() {
2125                contains_parameters = true;
2126            }
2127            Ok::<(), PlanError>(())
2128        })?;
2129        Ok(contains_parameters)
2130    }
2131
2132    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2133    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2134        #[allow(deprecated)]
2135        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2136                                                               depth: usize|
2137         -> Result<(), ()> {
2138            e.splice_parameters(params, depth);
2139            Ok(())
2140        });
2141    }
2142
2143    /// Constructs a constant collection from specific rows and schema.
2144    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2145        let rows = rows
2146            .into_iter()
2147            .map(move |datums| Row::pack_slice(&datums))
2148            .collect();
2149        HirRelationExpr::Constant { rows, typ }
2150    }
2151
2152    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2153    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2154    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2155    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2156    /// finishing into a trivial finishing.
2157    pub fn finish_maintained(
2158        &mut self,
2159        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2160        group_size_hints: GroupSizeHints,
2161    ) {
2162        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2163            let old_finishing = mem::replace(
2164                finishing,
2165                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2166            );
2167            *self = HirRelationExpr::top_k(
2168                std::mem::replace(
2169                    self,
2170                    HirRelationExpr::Constant {
2171                        rows: vec![],
2172                        typ: SqlRelationType::new(Vec::new()),
2173                    },
2174                ),
2175                vec![],
2176                old_finishing.order_by,
2177                old_finishing.limit,
2178                old_finishing.offset,
2179                group_size_hints.limit_input_group_size,
2180            )
2181            .project(old_finishing.project);
2182        }
2183    }
2184
2185    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2186    ///
2187    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2188    /// parameter is not an HirScalarExpr anymore.)
2189    pub fn trivial_row_set_finishing_hir(
2190        arity: usize,
2191    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2192        RowSetFinishing {
2193            order_by: Vec::new(),
2194            limit: None,
2195            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2196            project: (0..arity).collect(),
2197        }
2198    }
2199
2200    /// True if the finishing does nothing to any result set.
2201    ///
2202    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2203    /// parameter is not an HirScalarExpr anymore.)
2204    pub fn is_trivial_row_set_finishing_hir(
2205        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2206        arity: usize,
2207    ) -> bool {
2208        rsf.limit.is_none()
2209            && rsf.order_by.is_empty()
2210            && rsf
2211                .offset
2212                .clone()
2213                .try_into_literal_int64()
2214                .is_ok_and(|o| o == 0)
2215            && rsf.project.iter().copied().eq(0..arity)
2216    }
2217
2218    /// The HirRelationExpr is considered potentially expensive if and only if
2219    /// at least one of the following conditions is true:
2220    ///
2221    ///  - It contains at least one CallTable or a Reduce operator.
2222    ///  - It contains at least one HirScalarExpr with a function call.
2223    ///
2224    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2225    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2226    pub fn could_run_expensive_function(&self) -> bool {
2227        let mut result = false;
2228        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2229            use HirRelationExpr::*;
2230            use HirScalarExpr::*;
2231
2232            self.visit_children(|scalar: &HirScalarExpr| {
2233                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2234                    result |= match scalar {
2235                        Column(..)
2236                        | Literal(..)
2237                        | CallUnmaterializable(..)
2238                        | If { .. }
2239                        | Parameter(..)
2240                        | Select(..)
2241                        | Exists(..) => false,
2242                        // Function calls are considered expensive
2243                        CallUnary { .. }
2244                        | CallBinary { .. }
2245                        | CallVariadic { .. }
2246                        | Windowing(..) => true,
2247                    };
2248                }) {
2249                    // Conservatively set `true` on RecursionLimitError.
2250                    result = true;
2251                }
2252            });
2253
2254            // CallTable has a table function; Reduce has an aggregate function.
2255            // Other constructs use MirScalarExpr to run a function
2256            result |= matches!(e, CallTable { .. } | Reduce { .. });
2257        }) {
2258            // Conservatively set `true` on RecursionLimitError.
2259            result = true;
2260        }
2261
2262        result
2263    }
2264
2265    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2266    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2267        let mut contains = false;
2268        self.visit_post(&mut |expr| {
2269            expr.visit_children(|expr: &HirScalarExpr| {
2270                contains = contains || expr.contains_temporal()
2271            })
2272        })?;
2273        Ok(contains)
2274    }
2275}
2276
2277impl CollectionPlan for HirRelationExpr {
2278    // !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2279    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2280    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2281        if let Self::Get {
2282            id: Id::Global(id), ..
2283        } = self
2284        {
2285            out.insert(*id);
2286        }
2287        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2288    }
2289}
2290
2291impl VisitChildren<Self> for HirRelationExpr {
2292    fn visit_children<F>(&self, mut f: F)
2293    where
2294        F: FnMut(&Self),
2295    {
2296        // subqueries of type HirRelationExpr might be wrapped in
2297        // Exists or Select variants within HirScalarExpr trees
2298        // attached at the current node, and we want to visit them as well
2299        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2300            #[allow(deprecated)]
2301            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2302                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2303                    f(expr.as_ref())
2304                }
2305                _ => (),
2306            });
2307        });
2308
2309        use HirRelationExpr::*;
2310        match self {
2311            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2312            Let {
2313                name: _,
2314                id: _,
2315                value,
2316                body,
2317            } => {
2318                f(value);
2319                f(body);
2320            }
2321            LetRec {
2322                limit: _,
2323                bindings,
2324                body,
2325            } => {
2326                for (_, _, value, _) in bindings.iter() {
2327                    f(value);
2328                }
2329                f(body);
2330            }
2331            Project { input, outputs: _ } => f(input),
2332            Map { input, scalars: _ } => {
2333                f(input);
2334            }
2335            CallTable { func: _, exprs: _ } => (),
2336            Filter {
2337                input,
2338                predicates: _,
2339            } => {
2340                f(input);
2341            }
2342            Join {
2343                left,
2344                right,
2345                on: _,
2346                kind: _,
2347            } => {
2348                f(left);
2349                f(right);
2350            }
2351            Reduce {
2352                input,
2353                group_key: _,
2354                aggregates: _,
2355                expected_group_size: _,
2356            } => {
2357                f(input);
2358            }
2359            Distinct { input }
2360            | TopK {
2361                input,
2362                group_key: _,
2363                order_key: _,
2364                limit: _,
2365                offset: _,
2366                expected_group_size: _,
2367            }
2368            | Negate { input }
2369            | Threshold { input } => {
2370                f(input);
2371            }
2372            Union { base, inputs } => {
2373                f(base);
2374                for input in inputs {
2375                    f(input);
2376                }
2377            }
2378        }
2379    }
2380
2381    fn visit_mut_children<F>(&mut self, mut f: F)
2382    where
2383        F: FnMut(&mut Self),
2384    {
2385        // subqueries of type HirRelationExpr might be wrapped in
2386        // Exists or Select variants within HirScalarExpr trees
2387        // attached at the current node, and we want to visit them as well
2388        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2389            #[allow(deprecated)]
2390            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2391                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2392                    f(expr.as_mut())
2393                }
2394                _ => (),
2395            });
2396        });
2397
2398        use HirRelationExpr::*;
2399        match self {
2400            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2401            Let {
2402                name: _,
2403                id: _,
2404                value,
2405                body,
2406            } => {
2407                f(value);
2408                f(body);
2409            }
2410            LetRec {
2411                limit: _,
2412                bindings,
2413                body,
2414            } => {
2415                for (_, _, value, _) in bindings.iter_mut() {
2416                    f(value);
2417                }
2418                f(body);
2419            }
2420            Project { input, outputs: _ } => f(input),
2421            Map { input, scalars: _ } => {
2422                f(input);
2423            }
2424            CallTable { func: _, exprs: _ } => (),
2425            Filter {
2426                input,
2427                predicates: _,
2428            } => {
2429                f(input);
2430            }
2431            Join {
2432                left,
2433                right,
2434                on: _,
2435                kind: _,
2436            } => {
2437                f(left);
2438                f(right);
2439            }
2440            Reduce {
2441                input,
2442                group_key: _,
2443                aggregates: _,
2444                expected_group_size: _,
2445            } => {
2446                f(input);
2447            }
2448            Distinct { input }
2449            | TopK {
2450                input,
2451                group_key: _,
2452                order_key: _,
2453                limit: _,
2454                offset: _,
2455                expected_group_size: _,
2456            }
2457            | Negate { input }
2458            | Threshold { input } => {
2459                f(input);
2460            }
2461            Union { base, inputs } => {
2462                f(base);
2463                for input in inputs {
2464                    f(input);
2465                }
2466            }
2467        }
2468    }
2469
2470    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2471    where
2472        F: FnMut(&Self) -> Result<(), E>,
2473        E: From<RecursionLimitError>,
2474    {
2475        // subqueries of type HirRelationExpr might be wrapped in
2476        // Exists or Select variants within HirScalarExpr trees
2477        // attached at the current node, and we want to visit them as well
2478        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2479            Visit::try_visit_post(expr, &mut |expr| match expr {
2480                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2481                    f(expr.as_ref())
2482                }
2483                _ => Ok(()),
2484            })
2485        })?;
2486
2487        use HirRelationExpr::*;
2488        match self {
2489            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2490            Let {
2491                name: _,
2492                id: _,
2493                value,
2494                body,
2495            } => {
2496                f(value)?;
2497                f(body)?;
2498            }
2499            LetRec {
2500                limit: _,
2501                bindings,
2502                body,
2503            } => {
2504                for (_, _, value, _) in bindings.iter() {
2505                    f(value)?;
2506                }
2507                f(body)?;
2508            }
2509            Project { input, outputs: _ } => f(input)?,
2510            Map { input, scalars: _ } => {
2511                f(input)?;
2512            }
2513            CallTable { func: _, exprs: _ } => (),
2514            Filter {
2515                input,
2516                predicates: _,
2517            } => {
2518                f(input)?;
2519            }
2520            Join {
2521                left,
2522                right,
2523                on: _,
2524                kind: _,
2525            } => {
2526                f(left)?;
2527                f(right)?;
2528            }
2529            Reduce {
2530                input,
2531                group_key: _,
2532                aggregates: _,
2533                expected_group_size: _,
2534            } => {
2535                f(input)?;
2536            }
2537            Distinct { input }
2538            | TopK {
2539                input,
2540                group_key: _,
2541                order_key: _,
2542                limit: _,
2543                offset: _,
2544                expected_group_size: _,
2545            }
2546            | Negate { input }
2547            | Threshold { input } => {
2548                f(input)?;
2549            }
2550            Union { base, inputs } => {
2551                f(base)?;
2552                for input in inputs {
2553                    f(input)?;
2554                }
2555            }
2556        }
2557        Ok(())
2558    }
2559
2560    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2561    where
2562        F: FnMut(&mut Self) -> Result<(), E>,
2563        E: From<RecursionLimitError>,
2564    {
2565        // subqueries of type HirRelationExpr might be wrapped in
2566        // Exists or Select variants within HirScalarExpr trees
2567        // attached at the current node, and we want to visit them as well
2568        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2569            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2570                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2571                    f(expr.as_mut())
2572                }
2573                _ => Ok(()),
2574            })
2575        })?;
2576
2577        use HirRelationExpr::*;
2578        match self {
2579            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2580            Let {
2581                name: _,
2582                id: _,
2583                value,
2584                body,
2585            } => {
2586                f(value)?;
2587                f(body)?;
2588            }
2589            LetRec {
2590                limit: _,
2591                bindings,
2592                body,
2593            } => {
2594                for (_, _, value, _) in bindings.iter_mut() {
2595                    f(value)?;
2596                }
2597                f(body)?;
2598            }
2599            Project { input, outputs: _ } => f(input)?,
2600            Map { input, scalars: _ } => {
2601                f(input)?;
2602            }
2603            CallTable { func: _, exprs: _ } => (),
2604            Filter {
2605                input,
2606                predicates: _,
2607            } => {
2608                f(input)?;
2609            }
2610            Join {
2611                left,
2612                right,
2613                on: _,
2614                kind: _,
2615            } => {
2616                f(left)?;
2617                f(right)?;
2618            }
2619            Reduce {
2620                input,
2621                group_key: _,
2622                aggregates: _,
2623                expected_group_size: _,
2624            } => {
2625                f(input)?;
2626            }
2627            Distinct { input }
2628            | TopK {
2629                input,
2630                group_key: _,
2631                order_key: _,
2632                limit: _,
2633                offset: _,
2634                expected_group_size: _,
2635            }
2636            | Negate { input }
2637            | Threshold { input } => {
2638                f(input)?;
2639            }
2640            Union { base, inputs } => {
2641                f(base)?;
2642                for input in inputs {
2643                    f(input)?;
2644                }
2645            }
2646        }
2647        Ok(())
2648    }
2649}
2650
2651impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2652    fn visit_children<F>(&self, mut f: F)
2653    where
2654        F: FnMut(&HirScalarExpr),
2655    {
2656        use HirRelationExpr::*;
2657        match self {
2658            Constant { rows: _, typ: _ }
2659            | Get { id: _, typ: _ }
2660            | Let {
2661                name: _,
2662                id: _,
2663                value: _,
2664                body: _,
2665            }
2666            | LetRec {
2667                limit: _,
2668                bindings: _,
2669                body: _,
2670            }
2671            | Project {
2672                input: _,
2673                outputs: _,
2674            } => (),
2675            Map { input: _, scalars } => {
2676                for scalar in scalars {
2677                    f(scalar);
2678                }
2679            }
2680            CallTable { func: _, exprs } => {
2681                for expr in exprs {
2682                    f(expr);
2683                }
2684            }
2685            Filter {
2686                input: _,
2687                predicates,
2688            } => {
2689                for predicate in predicates {
2690                    f(predicate);
2691                }
2692            }
2693            Join {
2694                left: _,
2695                right: _,
2696                on,
2697                kind: _,
2698            } => f(on),
2699            Reduce {
2700                input: _,
2701                group_key: _,
2702                aggregates,
2703                expected_group_size: _,
2704            } => {
2705                for aggregate in aggregates {
2706                    f(aggregate.expr.as_ref());
2707                }
2708            }
2709            TopK {
2710                input: _,
2711                group_key: _,
2712                order_key: _,
2713                limit,
2714                offset,
2715                expected_group_size: _,
2716            } => {
2717                if let Some(limit) = limit {
2718                    f(limit)
2719                }
2720                f(offset)
2721            }
2722            Distinct { input: _ }
2723            | Negate { input: _ }
2724            | Threshold { input: _ }
2725            | Union { base: _, inputs: _ } => (),
2726        }
2727    }
2728
2729    fn visit_mut_children<F>(&mut self, mut f: F)
2730    where
2731        F: FnMut(&mut HirScalarExpr),
2732    {
2733        use HirRelationExpr::*;
2734        match self {
2735            Constant { rows: _, typ: _ }
2736            | Get { id: _, typ: _ }
2737            | Let {
2738                name: _,
2739                id: _,
2740                value: _,
2741                body: _,
2742            }
2743            | LetRec {
2744                limit: _,
2745                bindings: _,
2746                body: _,
2747            }
2748            | Project {
2749                input: _,
2750                outputs: _,
2751            } => (),
2752            Map { input: _, scalars } => {
2753                for scalar in scalars {
2754                    f(scalar);
2755                }
2756            }
2757            CallTable { func: _, exprs } => {
2758                for expr in exprs {
2759                    f(expr);
2760                }
2761            }
2762            Filter {
2763                input: _,
2764                predicates,
2765            } => {
2766                for predicate in predicates {
2767                    f(predicate);
2768                }
2769            }
2770            Join {
2771                left: _,
2772                right: _,
2773                on,
2774                kind: _,
2775            } => f(on),
2776            Reduce {
2777                input: _,
2778                group_key: _,
2779                aggregates,
2780                expected_group_size: _,
2781            } => {
2782                for aggregate in aggregates {
2783                    f(aggregate.expr.as_mut());
2784                }
2785            }
2786            TopK {
2787                input: _,
2788                group_key: _,
2789                order_key: _,
2790                limit,
2791                offset,
2792                expected_group_size: _,
2793            } => {
2794                if let Some(limit) = limit {
2795                    f(limit)
2796                }
2797                f(offset)
2798            }
2799            Distinct { input: _ }
2800            | Negate { input: _ }
2801            | Threshold { input: _ }
2802            | Union { base: _, inputs: _ } => (),
2803        }
2804    }
2805
2806    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2807    where
2808        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2809        E: From<RecursionLimitError>,
2810    {
2811        use HirRelationExpr::*;
2812        match self {
2813            Constant { rows: _, typ: _ }
2814            | Get { id: _, typ: _ }
2815            | Let {
2816                name: _,
2817                id: _,
2818                value: _,
2819                body: _,
2820            }
2821            | LetRec {
2822                limit: _,
2823                bindings: _,
2824                body: _,
2825            }
2826            | Project {
2827                input: _,
2828                outputs: _,
2829            } => (),
2830            Map { input: _, scalars } => {
2831                for scalar in scalars {
2832                    f(scalar)?;
2833                }
2834            }
2835            CallTable { func: _, exprs } => {
2836                for expr in exprs {
2837                    f(expr)?;
2838                }
2839            }
2840            Filter {
2841                input: _,
2842                predicates,
2843            } => {
2844                for predicate in predicates {
2845                    f(predicate)?;
2846                }
2847            }
2848            Join {
2849                left: _,
2850                right: _,
2851                on,
2852                kind: _,
2853            } => f(on)?,
2854            Reduce {
2855                input: _,
2856                group_key: _,
2857                aggregates,
2858                expected_group_size: _,
2859            } => {
2860                for aggregate in aggregates {
2861                    f(aggregate.expr.as_ref())?;
2862                }
2863            }
2864            TopK {
2865                input: _,
2866                group_key: _,
2867                order_key: _,
2868                limit,
2869                offset,
2870                expected_group_size: _,
2871            } => {
2872                if let Some(limit) = limit {
2873                    f(limit)?
2874                }
2875                f(offset)?
2876            }
2877            Distinct { input: _ }
2878            | Negate { input: _ }
2879            | Threshold { input: _ }
2880            | Union { base: _, inputs: _ } => (),
2881        }
2882        Ok(())
2883    }
2884
2885    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2886    where
2887        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2888        E: From<RecursionLimitError>,
2889    {
2890        use HirRelationExpr::*;
2891        match self {
2892            Constant { rows: _, typ: _ }
2893            | Get { id: _, typ: _ }
2894            | Let {
2895                name: _,
2896                id: _,
2897                value: _,
2898                body: _,
2899            }
2900            | LetRec {
2901                limit: _,
2902                bindings: _,
2903                body: _,
2904            }
2905            | Project {
2906                input: _,
2907                outputs: _,
2908            } => (),
2909            Map { input: _, scalars } => {
2910                for scalar in scalars {
2911                    f(scalar)?;
2912                }
2913            }
2914            CallTable { func: _, exprs } => {
2915                for expr in exprs {
2916                    f(expr)?;
2917                }
2918            }
2919            Filter {
2920                input: _,
2921                predicates,
2922            } => {
2923                for predicate in predicates {
2924                    f(predicate)?;
2925                }
2926            }
2927            Join {
2928                left: _,
2929                right: _,
2930                on,
2931                kind: _,
2932            } => f(on)?,
2933            Reduce {
2934                input: _,
2935                group_key: _,
2936                aggregates,
2937                expected_group_size: _,
2938            } => {
2939                for aggregate in aggregates {
2940                    f(aggregate.expr.as_mut())?;
2941                }
2942            }
2943            TopK {
2944                input: _,
2945                group_key: _,
2946                order_key: _,
2947                limit,
2948                offset,
2949                expected_group_size: _,
2950            } => {
2951                if let Some(limit) = limit {
2952                    f(limit)?
2953                }
2954                f(offset)?
2955            }
2956            Distinct { input: _ }
2957            | Negate { input: _ }
2958            | Threshold { input: _ }
2959            | Union { base: _, inputs: _ } => (),
2960        }
2961        Ok(())
2962    }
2963}
2964
2965impl HirScalarExpr {
2966    pub fn name(&self) -> Option<Arc<str>> {
2967        use HirScalarExpr::*;
2968        match self {
2969            Column(_, name)
2970            | Parameter(_, name)
2971            | Literal(_, _, name)
2972            | CallUnmaterializable(_, name)
2973            | CallUnary { name, .. }
2974            | CallBinary { name, .. }
2975            | CallVariadic { name, .. }
2976            | If { name, .. }
2977            | Exists(_, name)
2978            | Select(_, name)
2979            | Windowing(_, name) => name.0.clone(),
2980        }
2981    }
2982
2983    /// Replaces any parameter references in the expression with the
2984    /// corresponding datum in `params`.
2985    pub fn bind_parameters(
2986        &mut self,
2987        scx: &StatementContext,
2988        lifetime: QueryLifetime,
2989        params: &Params,
2990    ) -> Result<(), PlanError> {
2991        #[allow(deprecated)]
2992        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
2993            if let HirScalarExpr::Parameter(n, name) = e {
2994                let datum = match params.datums.iter().nth(*n - 1) {
2995                    None => return Err(PlanError::UnknownParameter(*n)),
2996                    Some(datum) => datum,
2997                };
2998                let scalar_type = &params.execute_types[*n - 1];
2999                let row = Row::pack([datum]);
3000                let column_type = scalar_type.clone().nullable(datum.is_null());
3001
3002                let name = if let Some(name) = &name.0 {
3003                    Some(Arc::clone(name))
3004                } else {
3005                    Some(Arc::from(format!("${n}")))
3006                };
3007
3008                let qcx = QueryContext::root(scx, lifetime);
3009                let ecx = execute_expr_context(&qcx);
3010
3011                *e = plan_cast(
3012                    &ecx,
3013                    *EXECUTE_CAST_CONTEXT,
3014                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3015                    &params.expected_types[*n - 1],
3016                )
3017                .expect("checked in plan_params");
3018            }
3019            Ok(())
3020        })
3021    }
3022
3023    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3024    /// replaced with the corresponding expression fragment from `params` rather
3025    /// than a datum.
3026    ///
3027    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3028    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3029    /// in `self` that refer to invalid indices of `params` will cause a panic.
3030    ///
3031    /// Column references in parameters will be corrected to account for the
3032    /// depth at which they are spliced.
3033    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3034        #[allow(deprecated)]
3035        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3036                                                        e: &mut HirScalarExpr|
3037         -> Result<(), ()> {
3038            if let HirScalarExpr::Parameter(i, _name) = e {
3039                *e = params[*i - 1].clone();
3040                // Correct any column references in the parameter expression for
3041                // its new depth.
3042                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3043                    if col.level >= d {
3044                        col.level += depth
3045                    }
3046                });
3047            }
3048            Ok(())
3049        });
3050    }
3051
3052    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3053    pub fn contains_temporal(&self) -> bool {
3054        let mut contains = false;
3055        #[allow(deprecated)]
3056        self.visit_post_nolimit(&mut |e| {
3057            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3058                contains = true;
3059            }
3060        });
3061        contains
3062    }
3063
3064    /// Constructs an unnamed column reference in the current scope.
3065    /// Use [`HirScalarExpr::named_column`] when a name is known.
3066    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3067    pub fn column(index: usize) -> HirScalarExpr {
3068        HirScalarExpr::Column(
3069            ColumnRef {
3070                level: 0,
3071                column: index,
3072            },
3073            TreatAsEqual(None),
3074        )
3075    }
3076
3077    /// Constructs an unnamed column reference.
3078    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3079        HirScalarExpr::Column(cr, TreatAsEqual(None))
3080    }
3081
3082    /// Constructs a named column reference.
3083    /// Names are interned by a `NameManager`.
3084    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3085        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3086    }
3087
3088    pub fn parameter(n: usize) -> HirScalarExpr {
3089        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3090    }
3091
3092    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3093        let col_type = scalar_type.nullable(datum.is_null());
3094        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3095        let row = Row::pack([datum]);
3096        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3097    }
3098
3099    pub fn literal_true() -> HirScalarExpr {
3100        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3101    }
3102
3103    pub fn literal_false() -> HirScalarExpr {
3104        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3105    }
3106
3107    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3108        HirScalarExpr::literal(Datum::Null, scalar_type)
3109    }
3110
3111    pub fn literal_1d_array(
3112        datums: Vec<Datum>,
3113        element_scalar_type: SqlScalarType,
3114    ) -> Result<HirScalarExpr, PlanError> {
3115        let scalar_type = match element_scalar_type {
3116            SqlScalarType::Array(_) => {
3117                sql_bail!("cannot build array from array type");
3118            }
3119            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3120        };
3121
3122        let mut row = Row::default();
3123        row.packer()
3124            .try_push_array(
3125                &[ArrayDimension {
3126                    lower_bound: 1,
3127                    length: datums.len(),
3128                }],
3129                datums,
3130            )
3131            .expect("array constructed to be valid");
3132
3133        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3134    }
3135
3136    pub fn as_literal(&self) -> Option<Datum<'_>> {
3137        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3138            Some(row.unpack_first())
3139        } else {
3140            None
3141        }
3142    }
3143
3144    pub fn is_literal_true(&self) -> bool {
3145        Some(Datum::True) == self.as_literal()
3146    }
3147
3148    pub fn is_literal_false(&self) -> bool {
3149        Some(Datum::False) == self.as_literal()
3150    }
3151
3152    pub fn is_literal_null(&self) -> bool {
3153        Some(Datum::Null) == self.as_literal()
3154    }
3155
3156    /// Return true iff `self` consists only of literals, materializable function calls, and
3157    /// if-else statements.
3158    pub fn is_constant(&self) -> bool {
3159        let mut worklist = vec![self];
3160        while let Some(expr) = worklist.pop() {
3161            match expr {
3162                Self::Literal(..) => {
3163                    // leaf node, do nothing
3164                }
3165                Self::CallUnary { expr, .. } => {
3166                    worklist.push(expr);
3167                }
3168                Self::CallBinary {
3169                    func: _,
3170                    expr1,
3171                    expr2,
3172                    name: _,
3173                } => {
3174                    worklist.push(expr1);
3175                    worklist.push(expr2);
3176                }
3177                Self::CallVariadic {
3178                    func: _,
3179                    exprs,
3180                    name: _,
3181                } => {
3182                    worklist.extend(exprs.iter());
3183                }
3184                // (CallUnmaterializable is not allowed)
3185                Self::If {
3186                    cond,
3187                    then,
3188                    els,
3189                    name: _,
3190                } => {
3191                    worklist.push(cond);
3192                    worklist.push(then);
3193                    worklist.push(els);
3194                }
3195                _ => {
3196                    return false; // Any other node makes `self` non-constant.
3197                }
3198            }
3199        }
3200        true
3201    }
3202
3203    pub fn call_unary(self, func: UnaryFunc) -> Self {
3204        HirScalarExpr::CallUnary {
3205            func,
3206            expr: Box::new(self),
3207            name: NameMetadata::default(),
3208        }
3209    }
3210
3211    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
3212        HirScalarExpr::CallBinary {
3213            func,
3214            expr1: Box::new(self),
3215            expr2: Box::new(other),
3216            name: NameMetadata::default(),
3217        }
3218    }
3219
3220    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3221        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3222    }
3223
3224    pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3225        HirScalarExpr::CallVariadic {
3226            func,
3227            exprs,
3228            name: NameMetadata::default(),
3229        }
3230    }
3231
3232    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3233        HirScalarExpr::If {
3234            cond: Box::new(cond),
3235            then: Box::new(then),
3236            els: Box::new(els),
3237            name: NameMetadata::default(),
3238        }
3239    }
3240
3241    pub fn windowing(expr: WindowExpr) -> Self {
3242        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3243    }
3244
3245    pub fn or(self, other: Self) -> Self {
3246        HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3247    }
3248
3249    pub fn and(self, other: Self) -> Self {
3250        HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3251    }
3252
3253    pub fn not(self) -> Self {
3254        self.call_unary(UnaryFunc::Not(func::Not))
3255    }
3256
3257    pub fn call_is_null(self) -> Self {
3258        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3259    }
3260
3261    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3262    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3263        match args.len() {
3264            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3265            1 => args.swap_remove(0),
3266            _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3267        }
3268    }
3269
3270    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3271    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3272        match args.len() {
3273            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3274            1 => args.swap_remove(0),
3275            _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3276        }
3277    }
3278
3279    pub fn take(&mut self) -> Self {
3280        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3281    }
3282
3283    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3284    /// Visits the column references in this scalar expression.
3285    ///
3286    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3287    /// which will be incremented with each subquery entered and presented to the supplied
3288    /// function `f`.
3289    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3290    where
3291        F: FnMut(usize, &ColumnRef),
3292    {
3293        #[allow(deprecated)]
3294        let _ = self.visit_recursively(depth, &mut |depth: usize,
3295                                                    e: &HirScalarExpr|
3296         -> Result<(), ()> {
3297            if let HirScalarExpr::Column(col, _name) = e {
3298                f(depth, col)
3299            }
3300            Ok(())
3301        });
3302    }
3303
3304    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3305    /// Like `visit_columns`, but permits mutating the column references.
3306    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3307    where
3308        F: FnMut(usize, &mut ColumnRef),
3309    {
3310        #[allow(deprecated)]
3311        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3312                                                        e: &mut HirScalarExpr|
3313         -> Result<(), ()> {
3314            if let HirScalarExpr::Column(col, _name) = e {
3315                f(depth, col)
3316            }
3317            Ok(())
3318        });
3319    }
3320
3321    /// Visits those column references in this scalar expression that refer to the root
3322    /// level. These include column references that are at the root level, as well as column
3323    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3324    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3325    /// "root level" to be `self`'s level.)
3326    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3327    where
3328        F: FnMut(usize),
3329    {
3330        #[allow(deprecated)]
3331        let _ = self.visit_recursively(0, &mut |depth: usize,
3332                                                e: &HirScalarExpr|
3333         -> Result<(), ()> {
3334            if let HirScalarExpr::Column(col, _name) = e {
3335                if col.level == depth {
3336                    f(col.column)
3337                }
3338            }
3339            Ok(())
3340        });
3341    }
3342
3343    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3344    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3345    where
3346        F: FnMut(&mut usize),
3347    {
3348        #[allow(deprecated)]
3349        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3350                                                    e: &mut HirScalarExpr|
3351         -> Result<(), ()> {
3352            if let HirScalarExpr::Column(col, _name) = e {
3353                if col.level == depth {
3354                    f(&mut col.column)
3355                }
3356            }
3357            Ok(())
3358        });
3359    }
3360
3361    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3362    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3363    /// in them. It takes the current depth of the expression and increases it when
3364    /// entering a subquery.
3365    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3366    where
3367        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3368    {
3369        match self {
3370            HirScalarExpr::Literal(..)
3371            | HirScalarExpr::Parameter(..)
3372            | HirScalarExpr::CallUnmaterializable(..)
3373            | HirScalarExpr::Column(..) => (),
3374            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3375            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3376                expr1.visit_recursively(depth, f)?;
3377                expr2.visit_recursively(depth, f)?;
3378            }
3379            HirScalarExpr::CallVariadic { exprs, .. } => {
3380                for expr in exprs {
3381                    expr.visit_recursively(depth, f)?;
3382                }
3383            }
3384            HirScalarExpr::If {
3385                cond,
3386                then,
3387                els,
3388                name: _,
3389            } => {
3390                cond.visit_recursively(depth, f)?;
3391                then.visit_recursively(depth, f)?;
3392                els.visit_recursively(depth, f)?;
3393            }
3394            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3395                #[allow(deprecated)]
3396                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3397                    e.visit_recursively(depth, f)
3398                })?;
3399            }
3400            HirScalarExpr::Windowing(expr, _name) => {
3401                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3402            }
3403        }
3404        f(depth, self)
3405    }
3406
3407    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3408    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3409    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3410    where
3411        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3412    {
3413        match self {
3414            HirScalarExpr::Literal(..)
3415            | HirScalarExpr::Parameter(..)
3416            | HirScalarExpr::CallUnmaterializable(..)
3417            | HirScalarExpr::Column(..) => (),
3418            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3419            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3420                expr1.visit_recursively_mut(depth, f)?;
3421                expr2.visit_recursively_mut(depth, f)?;
3422            }
3423            HirScalarExpr::CallVariadic { exprs, .. } => {
3424                for expr in exprs {
3425                    expr.visit_recursively_mut(depth, f)?;
3426                }
3427            }
3428            HirScalarExpr::If {
3429                cond,
3430                then,
3431                els,
3432                name: _,
3433            } => {
3434                cond.visit_recursively_mut(depth, f)?;
3435                then.visit_recursively_mut(depth, f)?;
3436                els.visit_recursively_mut(depth, f)?;
3437            }
3438            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3439                #[allow(deprecated)]
3440                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3441                    e.visit_recursively_mut(depth, f)
3442                })?;
3443            }
3444            HirScalarExpr::Windowing(expr, _name) => {
3445                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3446            }
3447        }
3448        f(depth, self)
3449    }
3450
3451    /// Attempts to simplify self into a literal.
3452    ///
3453    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3454    /// an evaluation error occurs during simplification, or if self contains
3455    /// - a subquery
3456    /// - a column reference to an outer level
3457    /// - a parameter
3458    /// - a window function call
3459    fn simplify_to_literal(self) -> Option<Row> {
3460        let mut expr = self.lower_uncorrelated().ok()?;
3461        expr.reduce(&[]);
3462        match expr {
3463            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3464            _ => None,
3465        }
3466    }
3467
3468    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3469    /// or an evaluation error occurs during simplification), it returns
3470    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3471    ///
3472    /// The returned error is an _internal_ error if the expression contains
3473    /// - a subquery
3474    /// - a column reference to an outer level
3475    /// - a parameter
3476    /// - a window function call
3477    ///
3478    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3479    /// msg.
3480    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3481        let mut expr = self.lower_uncorrelated().map_err(|err| {
3482            PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3483        })?;
3484        expr.reduce(&[]);
3485        match expr {
3486            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3487            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3488                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3489            ),
3490            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3491                "Not a constant".to_string(),
3492            )),
3493        }
3494    }
3495
3496    /// Attempts to simplify this expression to a literal 64-bit integer.
3497    ///
3498    /// Returns `None` if this expression cannot be simplified, e.g. because it
3499    /// contains non-literal values.
3500    ///
3501    /// # Panics
3502    ///
3503    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3504    pub fn into_literal_int64(self) -> Option<i64> {
3505        self.simplify_to_literal().and_then(|row| {
3506            let datum = row.unpack_first();
3507            if datum.is_null() {
3508                None
3509            } else {
3510                Some(datum.unwrap_int64())
3511            }
3512        })
3513    }
3514
3515    /// Attempts to simplify this expression to a literal string.
3516    ///
3517    /// Returns `None` if this expression cannot be simplified, e.g. because it
3518    /// contains non-literal values.
3519    ///
3520    /// # Panics
3521    ///
3522    /// Panics if this expression does not have type [`SqlScalarType::String`].
3523    pub fn into_literal_string(self) -> Option<String> {
3524        self.simplify_to_literal().and_then(|row| {
3525            let datum = row.unpack_first();
3526            if datum.is_null() {
3527                None
3528            } else {
3529                Some(datum.unwrap_str().to_owned())
3530            }
3531        })
3532    }
3533
3534    /// Attempts to simplify this expression to a literal MzTimestamp.
3535    ///
3536    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3537    /// simplified, e.g. because it contains non-literal values or a cast fails.
3538    ///
3539    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3540    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3541    ///
3542    /// # Panics
3543    ///
3544    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3545    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3546        self.simplify_to_literal().and_then(|row| {
3547            let datum = row.unpack_first();
3548            if datum.is_null() {
3549                None
3550            } else {
3551                Some(datum.unwrap_mz_timestamp())
3552            }
3553        })
3554    }
3555
3556    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3557    /// returns it as an i64.
3558    ///
3559    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3560    /// - it's not a constant expression (as determined by `is_constant`)
3561    /// - evaluates to null
3562    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3563    ///
3564    /// # Panics
3565    ///
3566    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3567    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3568        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3569        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3570        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3571        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3572        // preconditions also of all the other into_literal_... functions.)
3573        if !self.is_constant() {
3574            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3575                "Expected a constant expression, got {}",
3576                self
3577            )));
3578        }
3579        self.clone()
3580            .simplify_to_literal_with_result()
3581            .and_then(|row| {
3582                let datum = row.unpack_first();
3583                if datum.is_null() {
3584                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3585                        "Expected an expression that evaluates to a non-null value, got {}",
3586                        self
3587                    )))
3588                } else {
3589                    Ok(datum.unwrap_int64())
3590                }
3591            })
3592    }
3593
3594    pub fn contains_parameters(&self) -> bool {
3595        let mut contains_parameters = false;
3596        #[allow(deprecated)]
3597        let _ = self.visit_recursively(0, &mut |_depth: usize,
3598                                                expr: &HirScalarExpr|
3599         -> Result<(), ()> {
3600            if let HirScalarExpr::Parameter(..) = expr {
3601                contains_parameters = true;
3602            }
3603            Ok(())
3604        });
3605        contains_parameters
3606    }
3607}
3608
3609impl VisitChildren<Self> for HirScalarExpr {
3610    fn visit_children<F>(&self, mut f: F)
3611    where
3612        F: FnMut(&Self),
3613    {
3614        use HirScalarExpr::*;
3615        match self {
3616            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3617            CallUnary { expr, .. } => f(expr),
3618            CallBinary { expr1, expr2, .. } => {
3619                f(expr1);
3620                f(expr2);
3621            }
3622            CallVariadic { exprs, .. } => {
3623                for expr in exprs {
3624                    f(expr);
3625                }
3626            }
3627            If {
3628                cond,
3629                then,
3630                els,
3631                name: _,
3632            } => {
3633                f(cond);
3634                f(then);
3635                f(els);
3636            }
3637            Exists(..) | Select(..) => (),
3638            Windowing(expr, _name) => expr.visit_children(f),
3639        }
3640    }
3641
3642    fn visit_mut_children<F>(&mut self, mut f: F)
3643    where
3644        F: FnMut(&mut Self),
3645    {
3646        use HirScalarExpr::*;
3647        match self {
3648            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3649            CallUnary { expr, .. } => f(expr),
3650            CallBinary { expr1, expr2, .. } => {
3651                f(expr1);
3652                f(expr2);
3653            }
3654            CallVariadic { exprs, .. } => {
3655                for expr in exprs {
3656                    f(expr);
3657                }
3658            }
3659            If {
3660                cond,
3661                then,
3662                els,
3663                name: _,
3664            } => {
3665                f(cond);
3666                f(then);
3667                f(els);
3668            }
3669            Exists(..) | Select(..) => (),
3670            Windowing(expr, _name) => expr.visit_mut_children(f),
3671        }
3672    }
3673
3674    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3675    where
3676        F: FnMut(&Self) -> Result<(), E>,
3677        E: From<RecursionLimitError>,
3678    {
3679        use HirScalarExpr::*;
3680        match self {
3681            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3682            CallUnary { expr, .. } => f(expr)?,
3683            CallBinary { expr1, expr2, .. } => {
3684                f(expr1)?;
3685                f(expr2)?;
3686            }
3687            CallVariadic { exprs, .. } => {
3688                for expr in exprs {
3689                    f(expr)?;
3690                }
3691            }
3692            If {
3693                cond,
3694                then,
3695                els,
3696                name: _,
3697            } => {
3698                f(cond)?;
3699                f(then)?;
3700                f(els)?;
3701            }
3702            Exists(..) | Select(..) => (),
3703            Windowing(expr, _name) => expr.try_visit_children(f)?,
3704        }
3705        Ok(())
3706    }
3707
3708    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3709    where
3710        F: FnMut(&mut Self) -> Result<(), E>,
3711        E: From<RecursionLimitError>,
3712    {
3713        use HirScalarExpr::*;
3714        match self {
3715            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3716            CallUnary { expr, .. } => f(expr)?,
3717            CallBinary { expr1, expr2, .. } => {
3718                f(expr1)?;
3719                f(expr2)?;
3720            }
3721            CallVariadic { exprs, .. } => {
3722                for expr in exprs {
3723                    f(expr)?;
3724                }
3725            }
3726            If {
3727                cond,
3728                then,
3729                els,
3730                name: _,
3731            } => {
3732                f(cond)?;
3733                f(then)?;
3734                f(els)?;
3735            }
3736            Exists(..) | Select(..) => (),
3737            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3738        }
3739        Ok(())
3740    }
3741}
3742
3743impl AbstractExpr for HirScalarExpr {
3744    type Type = SqlColumnType;
3745
3746    fn typ(
3747        &self,
3748        outers: &[SqlRelationType],
3749        inner: &SqlRelationType,
3750        params: &BTreeMap<usize, SqlScalarType>,
3751    ) -> Self::Type {
3752        stack::maybe_grow(|| match self {
3753            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3754                if *level == 0 {
3755                    inner.column_types[*column].clone()
3756                } else {
3757                    outers[*level - 1].column_types[*column].clone()
3758                }
3759            }
3760            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3761            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3762            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3763            HirScalarExpr::CallUnary {
3764                expr,
3765                func,
3766                name: _,
3767            } => func.output_type(expr.typ(outers, inner, params)),
3768            HirScalarExpr::CallBinary {
3769                expr1,
3770                expr2,
3771                func,
3772                name: _,
3773            } => func.output_type(
3774                expr1.typ(outers, inner, params),
3775                expr2.typ(outers, inner, params),
3776            ),
3777            HirScalarExpr::CallVariadic {
3778                exprs,
3779                func,
3780                name: _,
3781            } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3782            HirScalarExpr::If {
3783                cond: _,
3784                then,
3785                els,
3786                name: _,
3787            } => {
3788                let then_type = then.typ(outers, inner, params);
3789                let else_type = els.typ(outers, inner, params);
3790                then_type.union(&else_type).unwrap()
3791            }
3792            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3793            HirScalarExpr::Select(expr, _name) => {
3794                let mut outers = outers.to_vec();
3795                outers.insert(0, inner.clone());
3796                expr.typ(&outers, params)
3797                    .column_types
3798                    .into_element()
3799                    .nullable(true)
3800            }
3801            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3802        })
3803    }
3804}
3805
3806impl AggregateExpr {
3807    pub fn typ(
3808        &self,
3809        outers: &[SqlRelationType],
3810        inner: &SqlRelationType,
3811        params: &BTreeMap<usize, SqlScalarType>,
3812    ) -> SqlColumnType {
3813        self.func.output_type(self.expr.typ(outers, inner, params))
3814    }
3815
3816    /// Returns whether the expression is COUNT(*) or not.  Note that
3817    /// when we define the count builtin in sql::func, we convert
3818    /// COUNT(*) to COUNT(true), making it indistinguishable from
3819    /// literal COUNT(true), but we prefer to consider this as the
3820    /// former.
3821    ///
3822    /// (MIR has the same `is_count_asterisk`.)
3823    pub fn is_count_asterisk(&self) -> bool {
3824        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3825    }
3826}