Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Marker type that names HIR as an implementation of the [`IR`] abstraction.
///
/// It is a unit struct and carries no data; it only exists so that traits such as [`IR`] and
/// [`AlgExcept`] can be implemented for HIR.
#[allow(missing_debug_implementations)]
pub struct Hir;
48
/// Binds the associated types of [`IR`] to their HIR representations.
impl IR for Hir {
    /// Relations are represented by [`HirRelationExpr`].
    type Relation = HirRelationExpr;
    /// Scalars are represented by [`HirScalarExpr`].
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56        if *all {
57            let rhs = rhs.negate();
58            HirRelationExpr::union(lhs, rhs).threshold()
59        } else {
60            let lhs = lhs.distinct();
61            let rhs = rhs.distinct().negate();
62            HirRelationExpr::union(lhs, rhs).threshold()
63        }
64    }
65
66    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67        let mut result = None;
68
69        use HirRelationExpr::*;
70        if let Threshold { input } = expr {
71            if let Union { base: lhs, inputs } = input.as_ref() {
72                if let [rhs] = &inputs[..] {
73                    if let Negate { input: rhs } = rhs {
74                        match (lhs.as_ref(), rhs.as_ref()) {
75                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
76                                let all = false;
77                                let lhs = lhs.as_ref();
78                                let rhs = rhs.as_ref();
79                                result = Some(Except { all, lhs, rhs })
80                            }
81                            (lhs, rhs) => {
82                                let all = true;
83                                result = Some(Except { all, lhs, rhs })
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        result
92    }
93}
94
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
pub enum HirRelationExpr {
    /// A literal collection made of `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, with relation type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        /// The name that `value` is bound to.
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// NOTE(review): presumably `outputs` are column indices into `input`, as in MIR — confirm.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// NOTE(review): presumably appends one column per entry of `scalars`, as in MIR — confirm.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call to the table function `func` with the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// NOTE(review): presumably keeps the rows of `input` satisfying all `predicates`, as in
    /// MIR — confirm.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// NOTE(review): presumably the same kind of hint as `TopK::expected_group_size` —
        /// confirm.
        expected_group_size: Option<u64>,
    },
    /// This is the node that `Hir::except` wraps around both operands of a non-`ALL` `EXCEPT`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// This is the node that `Hir::except` wraps around the right-hand operand of an `EXCEPT`.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// A union of `base` with every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
198
/// Stored column metadata: an optional, shareable name attached to a [`HirScalarExpr`].
///
/// NOTE(review): the `TreatAsEqual` wrapper presumably makes the derived comparisons and hashes
/// of the surrounding expression ignore this name — confirm against `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
201
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant carries a [`NameMetadata`], either as its last positional field or as a field
/// called `name`.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A reference to a query parameter.
    /// NOTE(review): presumably the `usize` is the parameter number, as in `$1` — confirm.
    Parameter(usize, NameMetadata),
    /// A literal value, stored as a `Row` together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an [`UnmaterializableFunc`], which takes no argument expressions.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional with condition `cond` and the branches `then` and `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function invocation; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
243
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, together with its function-specific parameters.
    pub func: WindowExprType,
    /// The expressions that define the partitioning; empty when there is none.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
262
263impl WindowExpr {
264    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265    where
266        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267    {
268        #[allow(deprecated)]
269        self.func.visit_expressions(f)?;
270        for expr in self.partition_by.iter() {
271            f(expr)?;
272        }
273        for expr in self.order_by.iter() {
274            f(expr)?;
275        }
276        Ok(())
277    }
278
279    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280    where
281        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282    {
283        #[allow(deprecated)]
284        self.func.visit_expressions_mut(f)?;
285        for expr in self.partition_by.iter_mut() {
286            f(expr)?;
287        }
288        for expr in self.order_by.iter_mut() {
289            f(expr)?;
290        }
291        Ok(())
292    }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296    fn visit_children<F>(&self, mut f: F)
297    where
298        F: FnMut(&HirScalarExpr),
299    {
300        self.func.visit_children(&mut f);
301        for expr in self.partition_by.iter() {
302            f(expr);
303        }
304        for expr in self.order_by.iter() {
305            f(expr);
306        }
307    }
308
309    fn visit_mut_children<F>(&mut self, mut f: F)
310    where
311        F: FnMut(&mut HirScalarExpr),
312    {
313        self.func.visit_mut_children(&mut f);
314        for expr in self.partition_by.iter_mut() {
315            f(expr);
316        }
317        for expr in self.order_by.iter_mut() {
318            f(expr);
319        }
320    }
321
322    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323    where
324        F: FnMut(&HirScalarExpr) -> Result<(), E>,
325        E: From<RecursionLimitError>,
326    {
327        self.func.try_visit_children(&mut f)?;
328        for expr in self.partition_by.iter() {
329            f(expr)?;
330        }
331        for expr in self.order_by.iter() {
332            f(expr)?;
333        }
334        Ok(())
335    }
336
337    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338    where
339        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340        E: From<RecursionLimitError>,
341    {
342        self.func.try_visit_mut_children(&mut f)?;
343        for expr in self.partition_by.iter_mut() {
344            f(expr)?;
345        }
346        for expr in self.order_by.iter_mut() {
347            f(expr)?;
348        }
349        Ok(())
350    }
351}
352
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function; see [`ScalarWindowExpr`].
    Scalar(ScalarWindowExpr),
    /// A value window function; see [`ValueWindowExpr`].
    Value(ValueWindowExpr),
    /// An aggregate window function; see [`AggregateWindowExpr`].
    Aggregate(AggregateWindowExpr),
}
375
376impl WindowExprType {
377    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379    where
380        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381    {
382        #[allow(deprecated)]
383        match self {
384            Self::Scalar(expr) => expr.visit_expressions(f),
385            Self::Value(expr) => expr.visit_expressions(f),
386            Self::Aggregate(expr) => expr.visit_expressions(f),
387        }
388    }
389
390    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392    where
393        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394    {
395        #[allow(deprecated)]
396        match self {
397            Self::Scalar(expr) => expr.visit_expressions_mut(f),
398            Self::Value(expr) => expr.visit_expressions_mut(f),
399            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400        }
401    }
402
403    fn typ(
404        &self,
405        outers: &[SqlRelationType],
406        inner: &SqlRelationType,
407        params: &BTreeMap<usize, SqlScalarType>,
408    ) -> SqlColumnType {
409        match self {
410            Self::Scalar(expr) => expr.typ(outers, inner, params),
411            Self::Value(expr) => expr.typ(outers, inner, params),
412            Self::Aggregate(expr) => expr.typ(outers, inner, params),
413        }
414    }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418    fn visit_children<F>(&self, f: F)
419    where
420        F: FnMut(&HirScalarExpr),
421    {
422        match self {
423            Self::Scalar(_) => (),
424            Self::Value(expr) => expr.visit_children(f),
425            Self::Aggregate(expr) => expr.visit_children(f),
426        }
427    }
428
429    fn visit_mut_children<F>(&mut self, f: F)
430    where
431        F: FnMut(&mut HirScalarExpr),
432    {
433        match self {
434            Self::Scalar(_) => (),
435            Self::Value(expr) => expr.visit_mut_children(f),
436            Self::Aggregate(expr) => expr.visit_mut_children(f),
437        }
438    }
439
440    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441    where
442        F: FnMut(&HirScalarExpr) -> Result<(), E>,
443        E: From<RecursionLimitError>,
444    {
445        match self {
446            Self::Scalar(_) => Ok(()),
447            Self::Value(expr) => expr.try_visit_children(f),
448            Self::Aggregate(expr) => expr.try_visit_children(f),
449        }
450    }
451
452    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453    where
454        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455        E: From<RecursionLimitError>,
456    {
457        match self {
458            Self::Scalar(_) => Ok(()),
459            Self::Value(expr) => expr.try_visit_mut_children(f),
460            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461        }
462    }
463}
464
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A scalar window function together with its ordering.
///
/// It holds no [`HirScalarExpr`], so it has no scalar-expression children to visit.
pub struct ScalarWindowExpr {
    /// Which scalar window function is invoked.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
470
471impl ScalarWindowExpr {
472    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474    where
475        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476    {
477        match self.func {
478            ScalarWindowFunc::RowNumber => {}
479            ScalarWindowFunc::Rank => {}
480            ScalarWindowFunc::DenseRank => {}
481        }
482        Ok(())
483    }
484
485    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487    where
488        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489    {
490        match self.func {
491            ScalarWindowFunc::RowNumber => {}
492            ScalarWindowFunc::Rank => {}
493            ScalarWindowFunc::DenseRank => {}
494        }
495        Ok(())
496    }
497
498    fn typ(
499        &self,
500        _outers: &[SqlRelationType],
501        _inner: &SqlRelationType,
502        _params: &BTreeMap<usize, SqlScalarType>,
503    ) -> SqlColumnType {
504        self.func.output_type()
505    }
506
507    pub fn into_expr(self) -> mz_expr::AggregateFunc {
508        match self.func {
509            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510                order_by: self.order_by,
511            },
512            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513                order_by: self.order_by,
514            },
515            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516                order_by: self.order_by,
517            },
518        }
519    }
520}
521
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
529
530impl Display for ScalarWindowFunc {
531    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532        match self {
533            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534            ScalarWindowFunc::Rank => write!(f, "rank"),
535            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536        }
537    }
538}
539
540impl ScalarWindowFunc {
541    pub fn output_type(&self) -> SqlColumnType {
542        match self {
543            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
544            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
545            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
546        }
547    }
548}
549
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A value window function together with its arguments, ordering, frame and null handling.
pub struct ValueWindowExpr {
    /// Which value window function is invoked.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame; `ValueWindowFunc::into_expr` passes it on for `first_value` and
    /// `last_value` only.
    pub window_frame: WindowFrame,
    /// `ValueWindowFunc::into_expr` passes this flag on for `lag` and `lead` only.
    pub ignore_nulls: bool,
}
564
565impl Display for ValueWindowFunc {
566    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567        match self {
568            ValueWindowFunc::Lag => write!(f, "lag"),
569            ValueWindowFunc::Lead => write!(f, "lead"),
570            ValueWindowFunc::FirstValue => write!(f, "first_value"),
571            ValueWindowFunc::LastValue => write!(f, "last_value"),
572            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573        }
574    }
575}
576
577impl ValueWindowExpr {
578    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580    where
581        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582    {
583        f(&self.args)
584    }
585
586    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588    where
589        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590    {
591        f(&mut self.args)
592    }
593
594    fn typ(
595        &self,
596        outers: &[SqlRelationType],
597        inner: &SqlRelationType,
598        params: &BTreeMap<usize, SqlScalarType>,
599    ) -> SqlColumnType {
600        self.func.output_type(self.args.typ(outers, inner, params))
601    }
602
603    /// Converts into `mz_expr::AggregateFunc`.
604    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605        (
606            self.args,
607            self.func
608                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609        )
610    }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614    fn visit_children<F>(&self, mut f: F)
615    where
616        F: FnMut(&HirScalarExpr),
617    {
618        f(&self.args)
619    }
620
621    fn visit_mut_children<F>(&mut self, mut f: F)
622    where
623        F: FnMut(&mut HirScalarExpr),
624    {
625        f(&mut self.args)
626    }
627
628    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629    where
630        F: FnMut(&HirScalarExpr) -> Result<(), E>,
631        E: From<RecursionLimitError>,
632    {
633        f(&self.args)
634    }
635
636    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637    where
638        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639        E: From<RecursionLimitError>,
640    {
641        f(&mut self.args)
642    }
643}
644
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// Displayed as `lag`.
    Lag,
    /// Displayed as `lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions fused into one; displayed as `fused[...]`.
    Fused(Vec<ValueWindowFunc>),
}
654
655impl ValueWindowFunc {
656    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
657        match self {
658            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659                // The input is a (value, offset, default) record, so extract the type of the first arg
660                input_type.scalar_type.unwrap_record_element_type()[0]
661                    .clone()
662                    .nullable(true)
663            }
664            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665                input_type.scalar_type.nullable(true)
666            }
667            ValueWindowFunc::Fused(funcs) => {
668                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669                SqlScalarType::Record {
670                    fields: funcs
671                        .iter()
672                        .zip_eq(input_types)
673                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674                        .collect(),
675                    custom_id: None,
676                }
677                .nullable(false)
678            }
679        }
680    }
681
682    pub fn into_expr(
683        self,
684        order_by: Vec<ColumnOrder>,
685        window_frame: WindowFrame,
686        ignore_nulls: bool,
687    ) -> mz_expr::AggregateFunc {
688        match self {
689            // Lag and Lead are fundamentally the same function, just with opposite directions
690            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691                order_by,
692                lag_lead: mz_expr::LagLeadType::Lag,
693                ignore_nulls,
694            },
695            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696                order_by,
697                lag_lead: mz_expr::LagLeadType::Lead,
698                ignore_nulls,
699            },
700            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701                order_by,
702                window_frame,
703            },
704            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705                order_by,
706                window_frame,
707            },
708            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709                funcs: funcs
710                    .into_iter()
711                    .map(|func| {
712                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713                    })
714                    .collect(),
715                order_by,
716            },
717        }
718    }
719}
720
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// An aggregate function used as a window function, together with its ordering and frame.
pub struct AggregateWindowExpr {
    /// The aggregate being computed; its `expr` is the only scalar-expression child.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame handed to the resulting `mz_expr::AggregateFunc` by `into_expr`.
    pub window_frame: WindowFrame,
}
727
728impl AggregateWindowExpr {
729    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731    where
732        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733    {
734        f(&self.aggregate_expr.expr)
735    }
736
737    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739    where
740        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741    {
742        f(&mut self.aggregate_expr.expr)
743    }
744
745    fn typ(
746        &self,
747        outers: &[SqlRelationType],
748        inner: &SqlRelationType,
749        params: &BTreeMap<usize, SqlScalarType>,
750    ) -> SqlColumnType {
751        self.aggregate_expr
752            .func
753            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754    }
755
756    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758            (
759                self.aggregate_expr.expr,
760                FusedWindowAggregate {
761                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762                    order_by: self.order_by,
763                    window_frame: self.window_frame,
764                },
765            )
766        } else {
767            (
768                self.aggregate_expr.expr,
769                WindowAggregate {
770                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771                    order_by: self.order_by,
772                    window_frame: self.window_frame,
773                },
774            )
775        }
776    }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780    fn visit_children<F>(&self, mut f: F)
781    where
782        F: FnMut(&HirScalarExpr),
783    {
784        f(&self.aggregate_expr.expr)
785    }
786
787    fn visit_mut_children<F>(&mut self, mut f: F)
788    where
789        F: FnMut(&mut HirScalarExpr),
790    {
791        f(&mut self.aggregate_expr.expr)
792    }
793
794    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795    where
796        F: FnMut(&HirScalarExpr) -> Result<(), E>,
797        E: From<RecursionLimitError>,
798    {
799        f(&self.aggregate_expr.expr)
800    }
801
802    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803    where
804        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805        E: From<RecursionLimitError>,
806    {
807        f(&mut self.aggregate_expr.expr)
808    }
809}
810
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that is already a fully planned [`HirScalarExpr`].
    Coerced(HirScalarExpr),
    /// A parameter (such as `$1`) that has not been coerced to a type yet.
    Parameter(usize),
    /// A `NULL` literal that has not been coerced to a type yet.
    LiteralNull,
    /// A string literal (such as `'42'`) that has not been coerced to a type yet.
    LiteralString(String),
    /// A record literal whose elements are themselves coercible expressions.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
843
844impl CoercibleScalarExpr {
845    pub fn type_as(
846        self,
847        ecx: &ExprContext,
848        ty: &SqlScalarType,
849    ) -> Result<HirScalarExpr, PlanError> {
850        let expr = typeconv::plan_coerce(ecx, self, ty)?;
851        let expr_ty = ecx.scalar_type(&expr);
852        if ty != &expr_ty {
853            sql_bail!(
854                "{} must have type {}, not type {}",
855                ecx.name,
856                ecx.humanize_scalar_type(ty, false),
857                ecx.humanize_scalar_type(&expr_ty, false),
858            );
859        }
860        Ok(expr)
861    }
862
863    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
864        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
865    }
866
867    pub fn cast_to(
868        self,
869        ecx: &ExprContext,
870        ccx: CastContext,
871        ty: &SqlScalarType,
872    ) -> Result<HirScalarExpr, PlanError> {
873        let expr = typeconv::plan_coerce(ecx, self, ty)?;
874        typeconv::plan_cast(ecx, ccx, expr, ty)
875    }
876}
877
/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The type of an already-coerced expression.
    Coerced(SqlColumnType),
    /// The type of a record literal, given as the types of its fields, which
    /// may themselves be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// The expression has not been coerced, so no type is known yet.
    Uncoerced,
}
885
886impl CoercibleColumnType {
887    /// Reports the nullability of the type.
888    pub fn nullable(&self) -> bool {
889        match self {
890            // A coerced value's nullability is known.
891            CoercibleColumnType::Coerced(ct) => ct.nullable,
892
893            // A literal record can never be null.
894            CoercibleColumnType::Record(_) => false,
895
896            // An uncoerced literal may be the literal `NULL`, so we have
897            // to conservatively assume it is nullable.
898            CoercibleColumnType::Uncoerced => true,
899        }
900    }
901}
902
/// The scalar type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// The scalar type of an already-coerced expression.
    Coerced(SqlScalarType),
    /// The type of a record literal, given as the column types of its fields,
    /// which may themselves be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// The expression has not been coerced, so no type is known yet.
    Uncoerced,
}
910
911impl CoercibleScalarType {
912    /// Reports whether the scalar type has been coerced.
913    pub fn is_coerced(&self) -> bool {
914        matches!(self, CoercibleScalarType::Coerced(_))
915    }
916
917    /// Returns the coerced scalar type, if the type is coerced.
918    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
919        match self {
920            CoercibleScalarType::Coerced(t) => Some(t),
921            _ => None,
922        }
923    }
924
925    /// If the type is coerced, apply the mapping function to the contained
926    /// scalar type.
927    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
928    where
929        F: FnOnce(SqlScalarType) -> SqlScalarType,
930    {
931        match self {
932            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
933            _ => self,
934        }
935    }
936
937    /// If the type is an coercible record, forcibly converts to a coerced
938    /// record type. Any uncoerced field types are assumed to be of type text.
939    ///
940    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
941    /// accepts a type hint that can indicate the types of uncoerced field
942    /// types.
943    pub fn force_coerced_if_record(&mut self) {
944        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
945            let mut fields = vec![];
946            for (i, uf) in uncoerced_fields.enumerate() {
947                let name = ColumnName::from(format!("f{}", i + 1));
948                let ty = match uf {
949                    CoercibleColumnType::Coerced(ty) => ty,
950                    CoercibleColumnType::Record(mut fields) => {
951                        convert(fields.drain(..)).nullable(false)
952                    }
953                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
954                };
955                fields.push((name, ty))
956            }
957            SqlScalarType::Record {
958                fields: fields.into(),
959                custom_id: None,
960            }
961        }
962
963        if let CoercibleScalarType::Record(fields) = self {
964            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
965        }
966    }
967}
968
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` gives types for outer relations, `inner` gives the type of
    /// the relation the expression is evaluated against, and `params` gives
    /// types for parameters, keyed by parameter number.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
983
984impl AbstractExpr for CoercibleScalarExpr {
985    type Type = CoercibleColumnType;
986
987    fn typ(
988        &self,
989        outers: &[SqlRelationType],
990        inner: &SqlRelationType,
991        params: &BTreeMap<usize, SqlScalarType>,
992    ) -> Self::Type {
993        match self {
994            CoercibleScalarExpr::Coerced(expr) => {
995                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
996            }
997            CoercibleScalarExpr::LiteralRecord(scalars) => {
998                let fields = scalars
999                    .iter()
1000                    .map(|s| s.typ(outers, inner, params))
1001                    .collect();
1002                CoercibleColumnType::Record(fields)
1003            }
1004            _ => CoercibleColumnType::Uncoerced,
1005        }
1006    }
1007}
1008
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object contained in this column type-like object.
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object, consuming `self`.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1020
impl AbstractColumnType for SqlColumnType {
    type AbstractScalarType = SqlScalarType;

    /// Returns the column's scalar type, discarding the rest of the column
    /// type.
    fn scalar_type(self) -> Self::AbstractScalarType {
        self.scalar_type
    }
}
1028
1029impl AbstractColumnType for CoercibleColumnType {
1030    type AbstractScalarType = CoercibleScalarType;
1031
1032    fn scalar_type(self) -> Self::AbstractScalarType {
1033        match self {
1034            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1035            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1036            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1037        }
1038    }
1039}
1040
1041impl From<HirScalarExpr> for CoercibleScalarExpr {
1042    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1043        CoercibleScalarExpr::Coerced(expr)
1044    }
1045}
1046
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// The scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// The level-local column identifier.
    pub column: usize,
}
1068
/// The kind of join performed by a `HirRelationExpr::Join`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// An inner join.
    Inner,
    /// A left outer join; the columns of the right input become nullable.
    LeftOuter,
    /// A right outer join; the columns of the left input become nullable.
    RightOuter,
    /// A full outer join; the columns of both inputs become nullable.
    FullOuter,
}
1076
1077impl fmt::Display for JoinKind {
1078    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1079        write!(
1080            f,
1081            "{}",
1082            match self {
1083                JoinKind::Inner => "Inner",
1084                JoinKind::LeftOuter => "LeftOuter",
1085                JoinKind::RightOuter => "RightOuter",
1086                JoinKind::FullOuter => "FullOuter",
1087            }
1088        )
1089    }
1090}
1091
1092impl JoinKind {
1093    pub fn can_be_correlated(&self) -> bool {
1094        match self {
1095            JoinKind::Inner | JoinKind::LeftOuter => true,
1096            JoinKind::RightOuter | JoinKind::FullOuter => false,
1097        }
1098    }
1099}
1100
/// The application of an [`AggregateFunc`] to a scalar expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are fed to the aggregate function.
    pub expr: Box<HirScalarExpr>,
    /// NOTE(review): presumably whether duplicate inputs are removed before
    /// aggregating (SQL `DISTINCT`) — confirm against the lowering code.
    pub distinct: bool,
}
1107
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, the nullability of the aggregate columns is more common
/// here than in `expr`, as these aggregates may be applied over empty
/// result sets and should be null in those cases, whereas `expr` variants
/// only return null values when supplied nulls as input.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    // Maximum over each supported input type; the output type is the input
    // type.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // Minimum over each supported input type; the output type is the input
    // type.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // Sums. The integer variants widen their output type (see
    // `output_type`); the float and numeric variants keep the input type.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// The only aggregate whose output is not nullable; its output type is
    /// `Int64`.
    Count,
    /// Boolean aggregate whose identity datum is `false`.
    Any,
    /// Boolean aggregate whose identity datum is `true`.
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// An order-sensitive aggregate whose output type is `String`.
    ///
    /// NOTE(review): the input layout presumably mirrors the other ordered
    /// aggregates (value first, then the `order_by` columns) — confirm.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, each
    /// component of which will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1211
1212impl AggregateFunc {
1213    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1214    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1215        match self {
1216            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1217            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1218            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1219            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1220            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1221            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1222            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1223            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1224            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1225            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1226            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1227            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1228            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1229            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1230            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1231            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1232            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1233            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1234            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1235            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1236            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1237            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1238            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1239            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1240            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1241            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1242            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1243            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1244            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1245            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1246            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1247            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1248            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1249            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1250            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1251            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1252            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1253            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1254            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1255            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1256            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1257            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1258            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1259            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1260            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1261            AggregateFunc::All => mz_expr::AggregateFunc::All,
1262            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1263            AggregateFunc::JsonbObjectAgg { order_by } => {
1264                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1265            }
1266            AggregateFunc::MapAgg {
1267                order_by,
1268                value_type,
1269            } => mz_expr::AggregateFunc::MapAgg {
1270                order_by,
1271                value_type,
1272            },
1273            AggregateFunc::ArrayConcat { order_by } => {
1274                mz_expr::AggregateFunc::ArrayConcat { order_by }
1275            }
1276            AggregateFunc::ListConcat { order_by } => {
1277                mz_expr::AggregateFunc::ListConcat { order_by }
1278            }
1279            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1280            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1281            // `AggregateWindowExpr::into_expr`.
1282            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1283                panic!("into_expr called on FusedWindowAgg")
1284            }
1285            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1286        }
1287    }
1288
1289    /// Returns a datum whose inclusion in the aggregation will not change its
1290    /// result.
1291    ///
1292    /// # Panics
1293    ///
1294    /// Panics if called on a `FusedWindowAgg`.
1295    pub fn identity_datum(&self) -> Datum<'static> {
1296        match self {
1297            AggregateFunc::Any => Datum::False,
1298            AggregateFunc::All => Datum::True,
1299            AggregateFunc::Dummy => Datum::Dummy,
1300            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1301            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1302            AggregateFunc::MaxNumeric
1303            | AggregateFunc::MaxInt16
1304            | AggregateFunc::MaxInt32
1305            | AggregateFunc::MaxInt64
1306            | AggregateFunc::MaxUInt16
1307            | AggregateFunc::MaxUInt32
1308            | AggregateFunc::MaxUInt64
1309            | AggregateFunc::MaxMzTimestamp
1310            | AggregateFunc::MaxFloat32
1311            | AggregateFunc::MaxFloat64
1312            | AggregateFunc::MaxBool
1313            | AggregateFunc::MaxString
1314            | AggregateFunc::MaxDate
1315            | AggregateFunc::MaxTimestamp
1316            | AggregateFunc::MaxTimestampTz
1317            | AggregateFunc::MaxInterval
1318            | AggregateFunc::MaxTime
1319            | AggregateFunc::MinNumeric
1320            | AggregateFunc::MinInt16
1321            | AggregateFunc::MinInt32
1322            | AggregateFunc::MinInt64
1323            | AggregateFunc::MinUInt16
1324            | AggregateFunc::MinUInt32
1325            | AggregateFunc::MinUInt64
1326            | AggregateFunc::MinMzTimestamp
1327            | AggregateFunc::MinFloat32
1328            | AggregateFunc::MinFloat64
1329            | AggregateFunc::MinBool
1330            | AggregateFunc::MinString
1331            | AggregateFunc::MinDate
1332            | AggregateFunc::MinTimestamp
1333            | AggregateFunc::MinTimestampTz
1334            | AggregateFunc::MinInterval
1335            | AggregateFunc::MinTime
1336            | AggregateFunc::SumInt16
1337            | AggregateFunc::SumInt32
1338            | AggregateFunc::SumInt64
1339            | AggregateFunc::SumUInt16
1340            | AggregateFunc::SumUInt32
1341            | AggregateFunc::SumUInt64
1342            | AggregateFunc::SumFloat32
1343            | AggregateFunc::SumFloat64
1344            | AggregateFunc::SumNumeric
1345            | AggregateFunc::Count
1346            | AggregateFunc::JsonbAgg { .. }
1347            | AggregateFunc::JsonbObjectAgg { .. }
1348            | AggregateFunc::MapAgg { .. }
1349            | AggregateFunc::StringAgg { .. } => Datum::Null,
1350            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1351                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1352                // in HIR planning, because it is introduced only during HIR transformation.
1353                //
1354                // The implementation could be something like the following, except that we need to
1355                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1356                // ```
1357                // let temp_storage = RowArena::new();
1358                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1359                // ```
1360                panic!("FusedWindowAgg doesn't have an identity_datum")
1361            }
1362        }
1363    }
1364
1365    /// The output column type for the result of an aggregation.
1366    ///
1367    /// The output column type also contains nullability information, which
1368    /// is (without further information) true for aggregations that are not
1369    /// counts.
1370    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1371        let scalar_type = match self {
1372            AggregateFunc::Count => SqlScalarType::Int64,
1373            AggregateFunc::Any => SqlScalarType::Bool,
1374            AggregateFunc::All => SqlScalarType::Bool,
1375            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1376            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1377            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1378            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1379            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1380                max_scale: Some(NumericMaxScale::ZERO),
1381            },
1382            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1383            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1384                max_scale: Some(NumericMaxScale::ZERO),
1385            },
1386            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1387                value_type: Box::new(value_type.clone()),
1388                custom_id: None,
1389            },
1390            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1391                match input_type.scalar_type {
1392                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1393                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1394                    _ => unreachable!(),
1395                }
1396            }
1397            AggregateFunc::MaxNumeric
1398            | AggregateFunc::MaxInt16
1399            | AggregateFunc::MaxInt32
1400            | AggregateFunc::MaxInt64
1401            | AggregateFunc::MaxUInt16
1402            | AggregateFunc::MaxUInt32
1403            | AggregateFunc::MaxUInt64
1404            | AggregateFunc::MaxMzTimestamp
1405            | AggregateFunc::MaxFloat32
1406            | AggregateFunc::MaxFloat64
1407            | AggregateFunc::MaxBool
1408            | AggregateFunc::MaxString
1409            | AggregateFunc::MaxDate
1410            | AggregateFunc::MaxTimestamp
1411            | AggregateFunc::MaxTimestampTz
1412            | AggregateFunc::MaxInterval
1413            | AggregateFunc::MaxTime
1414            | AggregateFunc::MinNumeric
1415            | AggregateFunc::MinInt16
1416            | AggregateFunc::MinInt32
1417            | AggregateFunc::MinInt64
1418            | AggregateFunc::MinUInt16
1419            | AggregateFunc::MinUInt32
1420            | AggregateFunc::MinUInt64
1421            | AggregateFunc::MinMzTimestamp
1422            | AggregateFunc::MinFloat32
1423            | AggregateFunc::MinFloat64
1424            | AggregateFunc::MinBool
1425            | AggregateFunc::MinString
1426            | AggregateFunc::MinDate
1427            | AggregateFunc::MinTimestamp
1428            | AggregateFunc::MinTimestampTz
1429            | AggregateFunc::MinInterval
1430            | AggregateFunc::MinTime
1431            | AggregateFunc::SumFloat32
1432            | AggregateFunc::SumFloat64
1433            | AggregateFunc::SumNumeric
1434            | AggregateFunc::Dummy => input_type.scalar_type,
1435            AggregateFunc::FusedWindowAgg { funcs } => {
1436                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1437                SqlScalarType::Record {
1438                    fields: funcs
1439                        .iter()
1440                        .zip_eq(input_types)
1441                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1442                        .collect(),
1443                    custom_id: None,
1444                }
1445            }
1446        };
1447        // max/min/sum return null on empty sets
1448        let nullable = !matches!(self, AggregateFunc::Count);
1449        scalar_type.nullable(nullable)
1450    }
1451
1452    pub fn is_order_sensitive(&self) -> bool {
1453        use AggregateFunc::*;
1454        matches!(
1455            self,
1456            JsonbAgg { .. }
1457                | JsonbObjectAgg { .. }
1458                | MapAgg { .. }
1459                | ArrayConcat { .. }
1460                | ListConcat { .. }
1461                | StringAgg { .. }
1462        )
1463    }
1464}
1465
1466impl HirRelationExpr {
1467    /// Gets the SQL type of a self-contained, top-level expression.
1468    pub fn top_level_typ(&self) -> SqlRelationType {
1469        self.typ(&[], &BTreeMap::new())
1470    }
1471
1472    /// Gets the SQL type of the expression.
1473    ///
1474    /// `outers` gives types for outer relations.
1475    /// `params` gives types for parameters.
1476    pub fn typ(
1477        &self,
1478        outers: &[SqlRelationType],
1479        params: &BTreeMap<usize, SqlScalarType>,
1480    ) -> SqlRelationType {
1481        stack::maybe_grow(|| match self {
1482            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1483            HirRelationExpr::Get { typ, .. } => typ.clone(),
1484            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1485            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1486            HirRelationExpr::Project { input, outputs } => {
1487                let input_typ = input.typ(outers, params);
1488                SqlRelationType::new(
1489                    outputs
1490                        .iter()
1491                        .map(|&i| input_typ.column_types[i].clone())
1492                        .collect(),
1493                )
1494            }
1495            HirRelationExpr::Map { input, scalars } => {
1496                let mut typ = input.typ(outers, params);
1497                for scalar in scalars {
1498                    typ.column_types.push(scalar.typ(outers, &typ, params));
1499                }
1500                typ
1501            }
1502            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1503            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1504                input.typ(outers, params)
1505            }
1506            HirRelationExpr::Join {
1507                left, right, kind, ..
1508            } => {
1509                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1510                let right_nullable =
1511                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1512                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1513                    let nullable = t.nullable || left_nullable;
1514                    t.nullable(nullable)
1515                });
1516                let mut outers = outers.to_vec();
1517                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1518                let rt = right
1519                    .typ(&outers, params)
1520                    .column_types
1521                    .into_iter()
1522                    .map(|t| {
1523                        let nullable = t.nullable || right_nullable;
1524                        t.nullable(nullable)
1525                    });
1526                SqlRelationType::new(lt.chain(rt).collect())
1527            }
1528            HirRelationExpr::Reduce {
1529                input,
1530                group_key,
1531                aggregates,
1532                expected_group_size: _,
1533            } => {
1534                let input_typ = input.typ(outers, params);
1535                let mut column_types = group_key
1536                    .iter()
1537                    .map(|&i| input_typ.column_types[i].clone())
1538                    .collect::<Vec<_>>();
1539                for agg in aggregates {
1540                    column_types.push(agg.typ(outers, &input_typ, params));
1541                }
1542                // TODO(frank): add primary key information.
1543                SqlRelationType::new(column_types)
1544            }
1545            // TODO(frank): check for removal; add primary key information.
1546            HirRelationExpr::Distinct { input }
1547            | HirRelationExpr::Negate { input }
1548            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1549            HirRelationExpr::Union { base, inputs } => {
1550                let mut base_cols = base.typ(outers, params).column_types;
1551                for input in inputs {
1552                    for (base_col, col) in base_cols
1553                        .iter_mut()
1554                        .zip_eq(input.typ(outers, params).column_types)
1555                    {
1556                        *base_col = base_col.union(&col).unwrap();
1557                    }
1558                }
1559                SqlRelationType::new(base_cols)
1560            }
1561        })
1562    }
1563
1564    pub fn arity(&self) -> usize {
1565        match self {
1566            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1567            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1568            HirRelationExpr::Let { body, .. } => body.arity(),
1569            HirRelationExpr::LetRec { body, .. } => body.arity(),
1570            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1571            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1572            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1573            HirRelationExpr::Filter { input, .. }
1574            | HirRelationExpr::TopK { input, .. }
1575            | HirRelationExpr::Distinct { input }
1576            | HirRelationExpr::Negate { input }
1577            | HirRelationExpr::Threshold { input } => input.arity(),
1578            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1579            HirRelationExpr::Union { base, .. } => base.arity(),
1580            HirRelationExpr::Reduce {
1581                group_key,
1582                aggregates,
1583                ..
1584            } => group_key.len() + aggregates.len(),
1585        }
1586    }
1587
1588    /// If self is a constant, return the value and the type, otherwise `None`.
1589    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1590        match self {
1591            Self::Constant { rows, typ } => Some((rows, typ)),
1592            _ => None,
1593        }
1594    }
1595
1596    /// Reports whether this expression contains a column reference to its
1597    /// direct parent scope.
1598    pub fn is_correlated(&self) -> bool {
1599        let mut correlated = false;
1600        #[allow(deprecated)]
1601        self.visit_columns(0, &mut |depth, col| {
1602            if col.level > depth && col.level - depth == 1 {
1603                correlated = true;
1604            }
1605        });
1606        correlated
1607    }
1608
1609    pub fn is_join_identity(&self) -> bool {
1610        match self {
1611            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1612            _ => false,
1613        }
1614    }
1615
1616    pub fn project(self, outputs: Vec<usize>) -> Self {
1617        if outputs.iter().copied().eq(0..self.arity()) {
1618            // The projection is trivial. Suppress it.
1619            self
1620        } else {
1621            HirRelationExpr::Project {
1622                input: Box::new(self),
1623                outputs,
1624            }
1625        }
1626    }
1627
1628    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1629        if scalars.is_empty() {
1630            // The map is trivial. Suppress it.
1631            self
1632        } else if let HirRelationExpr::Map {
1633            scalars: old_scalars,
1634            input: _,
1635        } = &mut self
1636        {
1637            // Map applied to a map. Fuse the maps.
1638            old_scalars.extend(scalars);
1639            self
1640        } else {
1641            HirRelationExpr::Map {
1642                input: Box::new(self),
1643                scalars,
1644            }
1645        }
1646    }
1647
1648    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1649        if let HirRelationExpr::Filter {
1650            input: _,
1651            predicates,
1652        } = &mut self
1653        {
1654            predicates.extend(preds);
1655            predicates.sort();
1656            predicates.dedup();
1657            self
1658        } else {
1659            preds.sort();
1660            preds.dedup();
1661            HirRelationExpr::Filter {
1662                input: Box::new(self),
1663                predicates: preds,
1664            }
1665        }
1666    }
1667
1668    pub fn reduce(
1669        self,
1670        group_key: Vec<usize>,
1671        aggregates: Vec<AggregateExpr>,
1672        expected_group_size: Option<u64>,
1673    ) -> Self {
1674        HirRelationExpr::Reduce {
1675            input: Box::new(self),
1676            group_key,
1677            aggregates,
1678            expected_group_size,
1679        }
1680    }
1681
1682    pub fn top_k(
1683        self,
1684        group_key: Vec<usize>,
1685        order_key: Vec<ColumnOrder>,
1686        limit: Option<HirScalarExpr>,
1687        offset: HirScalarExpr,
1688        expected_group_size: Option<u64>,
1689    ) -> Self {
1690        HirRelationExpr::TopK {
1691            input: Box::new(self),
1692            group_key,
1693            order_key,
1694            limit,
1695            offset,
1696            expected_group_size,
1697        }
1698    }
1699
1700    pub fn negate(self) -> Self {
1701        if let HirRelationExpr::Negate { input } = self {
1702            *input
1703        } else {
1704            HirRelationExpr::Negate {
1705                input: Box::new(self),
1706            }
1707        }
1708    }
1709
1710    pub fn distinct(self) -> Self {
1711        if let HirRelationExpr::Distinct { .. } = self {
1712            self
1713        } else {
1714            HirRelationExpr::Distinct {
1715                input: Box::new(self),
1716            }
1717        }
1718    }
1719
1720    pub fn threshold(self) -> Self {
1721        if let HirRelationExpr::Threshold { .. } = self {
1722            self
1723        } else {
1724            HirRelationExpr::Threshold {
1725                input: Box::new(self),
1726            }
1727        }
1728    }
1729
1730    pub fn union(self, other: Self) -> Self {
1731        let mut terms = Vec::new();
1732        if let HirRelationExpr::Union { base, inputs } = self {
1733            terms.push(*base);
1734            terms.extend(inputs);
1735        } else {
1736            terms.push(self);
1737        }
1738        if let HirRelationExpr::Union { base, inputs } = other {
1739            terms.push(*base);
1740            terms.extend(inputs);
1741        } else {
1742            terms.push(other);
1743        }
1744        HirRelationExpr::Union {
1745            base: Box::new(terms.remove(0)),
1746            inputs: terms,
1747        }
1748    }
1749
1750    pub fn exists(self) -> HirScalarExpr {
1751        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1752    }
1753
1754    pub fn select(self) -> HirScalarExpr {
1755        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1756    }
1757
1758    pub fn join(
1759        self,
1760        mut right: HirRelationExpr,
1761        on: HirScalarExpr,
1762        kind: JoinKind,
1763    ) -> HirRelationExpr {
1764        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1765        {
1766            // The join can be elided, but we need to adjust column references
1767            // on the right-hand side to account for the removal of the scope
1768            // introduced by the join.
1769            #[allow(deprecated)]
1770            right.visit_columns_mut(0, &mut |depth, col| {
1771                if col.level > depth {
1772                    col.level -= 1;
1773                }
1774            });
1775            right
1776        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1777            self
1778        } else {
1779            HirRelationExpr::Join {
1780                left: Box::new(self),
1781                right: Box::new(right),
1782                on,
1783                kind,
1784            }
1785        }
1786    }
1787
1788    pub fn take(&mut self) -> HirRelationExpr {
1789        mem::replace(
1790            self,
1791            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1792        )
1793    }
1794
1795    #[deprecated = "Use `Visit::visit_post`."]
1796    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1797    where
1798        F: FnMut(&'a Self, usize),
1799    {
1800        #[allow(deprecated)]
1801        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1802                                                 depth: usize|
1803         -> Result<(), ()> {
1804            f(e, depth);
1805            Ok(())
1806        });
1807    }
1808
1809    #[deprecated = "Use `Visit::try_visit_post`."]
1810    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1811    where
1812        F: FnMut(&'a Self, usize) -> Result<(), E>,
1813    {
1814        #[allow(deprecated)]
1815        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1816            e.visit_fallible(depth, f)
1817        })?;
1818        f(self, depth)
1819    }
1820
1821    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1822    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1823    where
1824        F: FnMut(&'a Self, usize) -> Result<(), E>,
1825    {
1826        match self {
1827            HirRelationExpr::Constant { .. }
1828            | HirRelationExpr::Get { .. }
1829            | HirRelationExpr::CallTable { .. } => (),
1830            HirRelationExpr::Let { body, value, .. } => {
1831                f(value, depth)?;
1832                f(body, depth)?;
1833            }
1834            HirRelationExpr::LetRec {
1835                limit: _,
1836                bindings,
1837                body,
1838            } => {
1839                for (_, _, value, _) in bindings.iter() {
1840                    f(value, depth)?;
1841                }
1842                f(body, depth)?;
1843            }
1844            HirRelationExpr::Project { input, .. } => {
1845                f(input, depth)?;
1846            }
1847            HirRelationExpr::Map { input, .. } => {
1848                f(input, depth)?;
1849            }
1850            HirRelationExpr::Filter { input, .. } => {
1851                f(input, depth)?;
1852            }
1853            HirRelationExpr::Join { left, right, .. } => {
1854                f(left, depth)?;
1855                f(right, depth + 1)?;
1856            }
1857            HirRelationExpr::Reduce { input, .. } => {
1858                f(input, depth)?;
1859            }
1860            HirRelationExpr::Distinct { input } => {
1861                f(input, depth)?;
1862            }
1863            HirRelationExpr::TopK { input, .. } => {
1864                f(input, depth)?;
1865            }
1866            HirRelationExpr::Negate { input } => {
1867                f(input, depth)?;
1868            }
1869            HirRelationExpr::Threshold { input } => {
1870                f(input, depth)?;
1871            }
1872            HirRelationExpr::Union { base, inputs } => {
1873                f(base, depth)?;
1874                for input in inputs {
1875                    f(input, depth)?;
1876                }
1877            }
1878        }
1879        Ok(())
1880    }
1881
1882    #[deprecated = "Use `Visit::visit_mut_post` instead."]
1883    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1884    where
1885        F: FnMut(&mut Self, usize),
1886    {
1887        #[allow(deprecated)]
1888        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1889                                                     depth: usize|
1890         -> Result<(), ()> {
1891            f(e, depth);
1892            Ok(())
1893        });
1894    }
1895
1896    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1897    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1898    where
1899        F: FnMut(&mut Self, usize) -> Result<(), E>,
1900    {
1901        #[allow(deprecated)]
1902        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1903            e.visit_mut_fallible(depth, f)
1904        })?;
1905        f(self, depth)
1906    }
1907
1908    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1909    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1910    where
1911        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1912    {
1913        match self {
1914            HirRelationExpr::Constant { .. }
1915            | HirRelationExpr::Get { .. }
1916            | HirRelationExpr::CallTable { .. } => (),
1917            HirRelationExpr::Let { body, value, .. } => {
1918                f(value, depth)?;
1919                f(body, depth)?;
1920            }
1921            HirRelationExpr::LetRec {
1922                limit: _,
1923                bindings,
1924                body,
1925            } => {
1926                for (_, _, value, _) in bindings.iter_mut() {
1927                    f(value, depth)?;
1928                }
1929                f(body, depth)?;
1930            }
1931            HirRelationExpr::Project { input, .. } => {
1932                f(input, depth)?;
1933            }
1934            HirRelationExpr::Map { input, .. } => {
1935                f(input, depth)?;
1936            }
1937            HirRelationExpr::Filter { input, .. } => {
1938                f(input, depth)?;
1939            }
1940            HirRelationExpr::Join { left, right, .. } => {
1941                f(left, depth)?;
1942                f(right, depth + 1)?;
1943            }
1944            HirRelationExpr::Reduce { input, .. } => {
1945                f(input, depth)?;
1946            }
1947            HirRelationExpr::Distinct { input } => {
1948                f(input, depth)?;
1949            }
1950            HirRelationExpr::TopK { input, .. } => {
1951                f(input, depth)?;
1952            }
1953            HirRelationExpr::Negate { input } => {
1954                f(input, depth)?;
1955            }
1956            HirRelationExpr::Threshold { input } => {
1957                f(input, depth)?;
1958            }
1959            HirRelationExpr::Union { base, inputs } => {
1960                f(base, depth)?;
1961                for input in inputs {
1962                    f(input, depth)?;
1963                }
1964            }
1965        }
1966        Ok(())
1967    }
1968
1969    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1970    /// Visits all scalar expressions within the sub-tree of the given relation.
1971    ///
1972    /// The `depth` argument should indicate the subquery nesting depth of the expression,
1973    /// which will be incremented when entering the RHS of a join or a subquery and
1974    /// presented to the supplied function `f`.
1975    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1976    where
1977        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1978    {
1979        #[allow(deprecated)]
1980        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1981                                         depth: usize|
1982         -> Result<(), E> {
1983            match e {
1984                HirRelationExpr::Join { on, .. } => {
1985                    f(on, depth)?;
1986                }
1987                HirRelationExpr::Map { scalars, .. } => {
1988                    for scalar in scalars {
1989                        f(scalar, depth)?;
1990                    }
1991                }
1992                HirRelationExpr::CallTable { exprs, .. } => {
1993                    for expr in exprs {
1994                        f(expr, depth)?;
1995                    }
1996                }
1997                HirRelationExpr::Filter { predicates, .. } => {
1998                    for predicate in predicates {
1999                        f(predicate, depth)?;
2000                    }
2001                }
2002                HirRelationExpr::Reduce { aggregates, .. } => {
2003                    for aggregate in aggregates {
2004                        f(&aggregate.expr, depth)?;
2005                    }
2006                }
2007                HirRelationExpr::TopK { limit, offset, .. } => {
2008                    if let Some(limit) = limit {
2009                        f(limit, depth)?;
2010                    }
2011                    f(offset, depth)?;
2012                }
2013                HirRelationExpr::Union { .. }
2014                | HirRelationExpr::Let { .. }
2015                | HirRelationExpr::LetRec { .. }
2016                | HirRelationExpr::Project { .. }
2017                | HirRelationExpr::Distinct { .. }
2018                | HirRelationExpr::Negate { .. }
2019                | HirRelationExpr::Threshold { .. }
2020                | HirRelationExpr::Constant { .. }
2021                | HirRelationExpr::Get { .. } => (),
2022            }
2023            Ok(())
2024        })
2025    }
2026
2027    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2028    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2029    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2030    where
2031        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2032    {
2033        #[allow(deprecated)]
2034        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2035                                             depth: usize|
2036         -> Result<(), E> {
2037            match e {
2038                HirRelationExpr::Join { on, .. } => {
2039                    f(on, depth)?;
2040                }
2041                HirRelationExpr::Map { scalars, .. } => {
2042                    for scalar in scalars.iter_mut() {
2043                        f(scalar, depth)?;
2044                    }
2045                }
2046                HirRelationExpr::CallTable { exprs, .. } => {
2047                    for expr in exprs.iter_mut() {
2048                        f(expr, depth)?;
2049                    }
2050                }
2051                HirRelationExpr::Filter { predicates, .. } => {
2052                    for predicate in predicates.iter_mut() {
2053                        f(predicate, depth)?;
2054                    }
2055                }
2056                HirRelationExpr::Reduce { aggregates, .. } => {
2057                    for aggregate in aggregates.iter_mut() {
2058                        f(&mut aggregate.expr, depth)?;
2059                    }
2060                }
2061                HirRelationExpr::TopK { limit, offset, .. } => {
2062                    if let Some(limit) = limit {
2063                        f(limit, depth)?;
2064                    }
2065                    f(offset, depth)?;
2066                }
2067                HirRelationExpr::Union { .. }
2068                | HirRelationExpr::Let { .. }
2069                | HirRelationExpr::LetRec { .. }
2070                | HirRelationExpr::Project { .. }
2071                | HirRelationExpr::Distinct { .. }
2072                | HirRelationExpr::Negate { .. }
2073                | HirRelationExpr::Threshold { .. }
2074                | HirRelationExpr::Constant { .. }
2075                | HirRelationExpr::Get { .. } => (),
2076            }
2077            Ok(())
2078        })
2079    }
2080
2081    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2082    /// Visits the column references in this relation expression.
2083    ///
2084    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2085    /// which will be incremented when entering the RHS of a join or a subquery and
2086    /// presented to the supplied function `f`.
2087    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2088    where
2089        F: FnMut(usize, &ColumnRef),
2090    {
2091        #[allow(deprecated)]
2092        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2093                                                           depth: usize|
2094         -> Result<(), ()> {
2095            e.visit_columns(depth, f);
2096            Ok(())
2097        });
2098    }
2099
2100    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2101    /// Like `visit_columns`, but permits mutating the column references.
2102    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2103    where
2104        F: FnMut(usize, &mut ColumnRef),
2105    {
2106        #[allow(deprecated)]
2107        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2108                                                               depth: usize|
2109         -> Result<(), ()> {
2110            e.visit_columns_mut(depth, f);
2111            Ok(())
2112        });
2113    }
2114
2115    /// Replaces any parameter references in the expression with the
2116    /// corresponding datum from `params`.
2117    pub fn bind_parameters(
2118        &mut self,
2119        scx: &StatementContext,
2120        lifetime: QueryLifetime,
2121        params: &Params,
2122    ) -> Result<(), PlanError> {
2123        #[allow(deprecated)]
2124        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2125            e.bind_parameters(scx, lifetime, params)
2126        })
2127    }
2128
2129    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2130        let mut contains_parameters = false;
2131        #[allow(deprecated)]
2132        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2133            if e.contains_parameters() {
2134                contains_parameters = true;
2135            }
2136            Ok::<(), PlanError>(())
2137        })?;
2138        Ok(contains_parameters)
2139    }
2140
2141    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2142    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2143        #[allow(deprecated)]
2144        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2145                                                               depth: usize|
2146         -> Result<(), ()> {
2147            e.splice_parameters(params, depth);
2148            Ok(())
2149        });
2150    }
2151
2152    /// Constructs a constant collection from specific rows and schema.
2153    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2154        let rows = rows
2155            .into_iter()
2156            .map(move |datums| Row::pack_slice(&datums))
2157            .collect();
2158        HirRelationExpr::Constant { rows, typ }
2159    }
2160
2161    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2162    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2163    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2164    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2165    /// finishing into a trivial finishing.
2166    pub fn finish_maintained(
2167        &mut self,
2168        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2169        group_size_hints: GroupSizeHints,
2170    ) {
2171        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2172            let old_finishing = mem::replace(
2173                finishing,
2174                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2175            );
2176            *self = HirRelationExpr::top_k(
2177                std::mem::replace(
2178                    self,
2179                    HirRelationExpr::Constant {
2180                        rows: vec![],
2181                        typ: SqlRelationType::new(Vec::new()),
2182                    },
2183                ),
2184                vec![],
2185                old_finishing.order_by,
2186                old_finishing.limit,
2187                old_finishing.offset,
2188                group_size_hints.limit_input_group_size,
2189            )
2190            .project(old_finishing.project);
2191        }
2192    }
2193
2194    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2195    ///
2196    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2197    /// parameter is not an HirScalarExpr anymore.)
2198    pub fn trivial_row_set_finishing_hir(
2199        arity: usize,
2200    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2201        RowSetFinishing {
2202            order_by: Vec::new(),
2203            limit: None,
2204            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2205            project: (0..arity).collect(),
2206        }
2207    }
2208
2209    /// True if the finishing does nothing to any result set.
2210    ///
2211    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2212    /// parameter is not an HirScalarExpr anymore.)
2213    pub fn is_trivial_row_set_finishing_hir(
2214        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2215        arity: usize,
2216    ) -> bool {
2217        rsf.limit.is_none()
2218            && rsf.order_by.is_empty()
2219            && rsf
2220                .offset
2221                .clone()
2222                .try_into_literal_int64()
2223                .is_ok_and(|o| o == 0)
2224            && rsf.project.iter().copied().eq(0..arity)
2225    }
2226
2227    /// The HirRelationExpr is considered potentially expensive if and only if
2228    /// at least one of the following conditions is true:
2229    ///
2230    ///  - It contains at least one CallTable or a Reduce operator.
2231    ///  - It contains at least one HirScalarExpr with a function call.
2232    ///
2233    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2234    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2235    pub fn could_run_expensive_function(&self) -> bool {
2236        let mut result = false;
2237        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2238            use HirRelationExpr::*;
2239            use HirScalarExpr::*;
2240
2241            self.visit_children(|scalar: &HirScalarExpr| {
2242                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2243                    result |= match scalar {
2244                        Column(..)
2245                        | Literal(..)
2246                        | CallUnmaterializable(..)
2247                        | If { .. }
2248                        | Parameter(..)
2249                        | Select(..)
2250                        | Exists(..) => false,
2251                        // Function calls are considered expensive
2252                        CallUnary { .. }
2253                        | CallBinary { .. }
2254                        | CallVariadic { .. }
2255                        | Windowing(..) => true,
2256                    };
2257                }) {
2258                    // Conservatively set `true` on RecursionLimitError.
2259                    result = true;
2260                }
2261            });
2262
2263            // CallTable has a table function; Reduce has an aggregate function.
2264            // Other constructs use MirScalarExpr to run a function
2265            result |= matches!(e, CallTable { .. } | Reduce { .. });
2266        }) {
2267            // Conservatively set `true` on RecursionLimitError.
2268            result = true;
2269        }
2270
2271        result
2272    }
2273
2274    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2275    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2276        let mut contains = false;
2277        self.visit_post(&mut |expr| {
2278            expr.visit_children(|expr: &HirScalarExpr| {
2279                contains = contains || expr.contains_temporal()
2280            })
2281        })?;
2282        Ok(contains)
2283    }
2284
2285    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2286    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2287        let mut contains = false;
2288        self.visit_post(&mut |expr| {
2289            expr.visit_children(|expr: &HirScalarExpr| {
2290                contains = contains || expr.contains_unmaterializable()
2291            })
2292        })?;
2293        Ok(contains)
2294    }
2295}
2296
2297impl CollectionPlan for HirRelationExpr {
2298    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2299    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2300    /// (It does explore inside subqueries.)
2301    ///
2302    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2303    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2304    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2305        if let Self::Get {
2306            id: Id::Global(id), ..
2307        } = self
2308        {
2309            out.insert(*id);
2310        }
2311        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2312    }
2313}
2314
2315impl VisitChildren<Self> for HirRelationExpr {
2316    fn visit_children<F>(&self, mut f: F)
2317    where
2318        F: FnMut(&Self),
2319    {
2320        // subqueries of type HirRelationExpr might be wrapped in
2321        // Exists or Select variants within HirScalarExpr trees
2322        // attached at the current node, and we want to visit them as well
2323        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2324            #[allow(deprecated)]
2325            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2326                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2327                    f(expr.as_ref())
2328                }
2329                _ => (),
2330            });
2331        });
2332
2333        use HirRelationExpr::*;
2334        match self {
2335            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2336            Let {
2337                name: _,
2338                id: _,
2339                value,
2340                body,
2341            } => {
2342                f(value);
2343                f(body);
2344            }
2345            LetRec {
2346                limit: _,
2347                bindings,
2348                body,
2349            } => {
2350                for (_, _, value, _) in bindings.iter() {
2351                    f(value);
2352                }
2353                f(body);
2354            }
2355            Project { input, outputs: _ } => f(input),
2356            Map { input, scalars: _ } => {
2357                f(input);
2358            }
2359            CallTable { func: _, exprs: _ } => (),
2360            Filter {
2361                input,
2362                predicates: _,
2363            } => {
2364                f(input);
2365            }
2366            Join {
2367                left,
2368                right,
2369                on: _,
2370                kind: _,
2371            } => {
2372                f(left);
2373                f(right);
2374            }
2375            Reduce {
2376                input,
2377                group_key: _,
2378                aggregates: _,
2379                expected_group_size: _,
2380            } => {
2381                f(input);
2382            }
2383            Distinct { input }
2384            | TopK {
2385                input,
2386                group_key: _,
2387                order_key: _,
2388                limit: _,
2389                offset: _,
2390                expected_group_size: _,
2391            }
2392            | Negate { input }
2393            | Threshold { input } => {
2394                f(input);
2395            }
2396            Union { base, inputs } => {
2397                f(base);
2398                for input in inputs {
2399                    f(input);
2400                }
2401            }
2402        }
2403    }
2404
2405    fn visit_mut_children<F>(&mut self, mut f: F)
2406    where
2407        F: FnMut(&mut Self),
2408    {
2409        // subqueries of type HirRelationExpr might be wrapped in
2410        // Exists or Select variants within HirScalarExpr trees
2411        // attached at the current node, and we want to visit them as well
2412        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2413            #[allow(deprecated)]
2414            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2415                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2416                    f(expr.as_mut())
2417                }
2418                _ => (),
2419            });
2420        });
2421
2422        use HirRelationExpr::*;
2423        match self {
2424            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2425            Let {
2426                name: _,
2427                id: _,
2428                value,
2429                body,
2430            } => {
2431                f(value);
2432                f(body);
2433            }
2434            LetRec {
2435                limit: _,
2436                bindings,
2437                body,
2438            } => {
2439                for (_, _, value, _) in bindings.iter_mut() {
2440                    f(value);
2441                }
2442                f(body);
2443            }
2444            Project { input, outputs: _ } => f(input),
2445            Map { input, scalars: _ } => {
2446                f(input);
2447            }
2448            CallTable { func: _, exprs: _ } => (),
2449            Filter {
2450                input,
2451                predicates: _,
2452            } => {
2453                f(input);
2454            }
2455            Join {
2456                left,
2457                right,
2458                on: _,
2459                kind: _,
2460            } => {
2461                f(left);
2462                f(right);
2463            }
2464            Reduce {
2465                input,
2466                group_key: _,
2467                aggregates: _,
2468                expected_group_size: _,
2469            } => {
2470                f(input);
2471            }
2472            Distinct { input }
2473            | TopK {
2474                input,
2475                group_key: _,
2476                order_key: _,
2477                limit: _,
2478                offset: _,
2479                expected_group_size: _,
2480            }
2481            | Negate { input }
2482            | Threshold { input } => {
2483                f(input);
2484            }
2485            Union { base, inputs } => {
2486                f(base);
2487                for input in inputs {
2488                    f(input);
2489                }
2490            }
2491        }
2492    }
2493
2494    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2495    where
2496        F: FnMut(&Self) -> Result<(), E>,
2497        E: From<RecursionLimitError>,
2498    {
2499        // subqueries of type HirRelationExpr might be wrapped in
2500        // Exists or Select variants within HirScalarExpr trees
2501        // attached at the current node, and we want to visit them as well
2502        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2503            Visit::try_visit_post(expr, &mut |expr| match expr {
2504                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2505                    f(expr.as_ref())
2506                }
2507                _ => Ok(()),
2508            })
2509        })?;
2510
2511        use HirRelationExpr::*;
2512        match self {
2513            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2514            Let {
2515                name: _,
2516                id: _,
2517                value,
2518                body,
2519            } => {
2520                f(value)?;
2521                f(body)?;
2522            }
2523            LetRec {
2524                limit: _,
2525                bindings,
2526                body,
2527            } => {
2528                for (_, _, value, _) in bindings.iter() {
2529                    f(value)?;
2530                }
2531                f(body)?;
2532            }
2533            Project { input, outputs: _ } => f(input)?,
2534            Map { input, scalars: _ } => {
2535                f(input)?;
2536            }
2537            CallTable { func: _, exprs: _ } => (),
2538            Filter {
2539                input,
2540                predicates: _,
2541            } => {
2542                f(input)?;
2543            }
2544            Join {
2545                left,
2546                right,
2547                on: _,
2548                kind: _,
2549            } => {
2550                f(left)?;
2551                f(right)?;
2552            }
2553            Reduce {
2554                input,
2555                group_key: _,
2556                aggregates: _,
2557                expected_group_size: _,
2558            } => {
2559                f(input)?;
2560            }
2561            Distinct { input }
2562            | TopK {
2563                input,
2564                group_key: _,
2565                order_key: _,
2566                limit: _,
2567                offset: _,
2568                expected_group_size: _,
2569            }
2570            | Negate { input }
2571            | Threshold { input } => {
2572                f(input)?;
2573            }
2574            Union { base, inputs } => {
2575                f(base)?;
2576                for input in inputs {
2577                    f(input)?;
2578                }
2579            }
2580        }
2581        Ok(())
2582    }
2583
2584    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2585    where
2586        F: FnMut(&mut Self) -> Result<(), E>,
2587        E: From<RecursionLimitError>,
2588    {
2589        // subqueries of type HirRelationExpr might be wrapped in
2590        // Exists or Select variants within HirScalarExpr trees
2591        // attached at the current node, and we want to visit them as well
2592        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2593            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2594                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2595                    f(expr.as_mut())
2596                }
2597                _ => Ok(()),
2598            })
2599        })?;
2600
2601        use HirRelationExpr::*;
2602        match self {
2603            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2604            Let {
2605                name: _,
2606                id: _,
2607                value,
2608                body,
2609            } => {
2610                f(value)?;
2611                f(body)?;
2612            }
2613            LetRec {
2614                limit: _,
2615                bindings,
2616                body,
2617            } => {
2618                for (_, _, value, _) in bindings.iter_mut() {
2619                    f(value)?;
2620                }
2621                f(body)?;
2622            }
2623            Project { input, outputs: _ } => f(input)?,
2624            Map { input, scalars: _ } => {
2625                f(input)?;
2626            }
2627            CallTable { func: _, exprs: _ } => (),
2628            Filter {
2629                input,
2630                predicates: _,
2631            } => {
2632                f(input)?;
2633            }
2634            Join {
2635                left,
2636                right,
2637                on: _,
2638                kind: _,
2639            } => {
2640                f(left)?;
2641                f(right)?;
2642            }
2643            Reduce {
2644                input,
2645                group_key: _,
2646                aggregates: _,
2647                expected_group_size: _,
2648            } => {
2649                f(input)?;
2650            }
2651            Distinct { input }
2652            | TopK {
2653                input,
2654                group_key: _,
2655                order_key: _,
2656                limit: _,
2657                offset: _,
2658                expected_group_size: _,
2659            }
2660            | Negate { input }
2661            | Threshold { input } => {
2662                f(input)?;
2663            }
2664            Union { base, inputs } => {
2665                f(base)?;
2666                for input in inputs {
2667                    f(input)?;
2668                }
2669            }
2670        }
2671        Ok(())
2672    }
2673}
2674
2675impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2676    fn visit_children<F>(&self, mut f: F)
2677    where
2678        F: FnMut(&HirScalarExpr),
2679    {
2680        use HirRelationExpr::*;
2681        match self {
2682            Constant { rows: _, typ: _ }
2683            | Get { id: _, typ: _ }
2684            | Let {
2685                name: _,
2686                id: _,
2687                value: _,
2688                body: _,
2689            }
2690            | LetRec {
2691                limit: _,
2692                bindings: _,
2693                body: _,
2694            }
2695            | Project {
2696                input: _,
2697                outputs: _,
2698            } => (),
2699            Map { input: _, scalars } => {
2700                for scalar in scalars {
2701                    f(scalar);
2702                }
2703            }
2704            CallTable { func: _, exprs } => {
2705                for expr in exprs {
2706                    f(expr);
2707                }
2708            }
2709            Filter {
2710                input: _,
2711                predicates,
2712            } => {
2713                for predicate in predicates {
2714                    f(predicate);
2715                }
2716            }
2717            Join {
2718                left: _,
2719                right: _,
2720                on,
2721                kind: _,
2722            } => f(on),
2723            Reduce {
2724                input: _,
2725                group_key: _,
2726                aggregates,
2727                expected_group_size: _,
2728            } => {
2729                for aggregate in aggregates {
2730                    f(aggregate.expr.as_ref());
2731                }
2732            }
2733            TopK {
2734                input: _,
2735                group_key: _,
2736                order_key: _,
2737                limit,
2738                offset,
2739                expected_group_size: _,
2740            } => {
2741                if let Some(limit) = limit {
2742                    f(limit)
2743                }
2744                f(offset)
2745            }
2746            Distinct { input: _ }
2747            | Negate { input: _ }
2748            | Threshold { input: _ }
2749            | Union { base: _, inputs: _ } => (),
2750        }
2751    }
2752
2753    fn visit_mut_children<F>(&mut self, mut f: F)
2754    where
2755        F: FnMut(&mut HirScalarExpr),
2756    {
2757        use HirRelationExpr::*;
2758        match self {
2759            Constant { rows: _, typ: _ }
2760            | Get { id: _, typ: _ }
2761            | Let {
2762                name: _,
2763                id: _,
2764                value: _,
2765                body: _,
2766            }
2767            | LetRec {
2768                limit: _,
2769                bindings: _,
2770                body: _,
2771            }
2772            | Project {
2773                input: _,
2774                outputs: _,
2775            } => (),
2776            Map { input: _, scalars } => {
2777                for scalar in scalars {
2778                    f(scalar);
2779                }
2780            }
2781            CallTable { func: _, exprs } => {
2782                for expr in exprs {
2783                    f(expr);
2784                }
2785            }
2786            Filter {
2787                input: _,
2788                predicates,
2789            } => {
2790                for predicate in predicates {
2791                    f(predicate);
2792                }
2793            }
2794            Join {
2795                left: _,
2796                right: _,
2797                on,
2798                kind: _,
2799            } => f(on),
2800            Reduce {
2801                input: _,
2802                group_key: _,
2803                aggregates,
2804                expected_group_size: _,
2805            } => {
2806                for aggregate in aggregates {
2807                    f(aggregate.expr.as_mut());
2808                }
2809            }
2810            TopK {
2811                input: _,
2812                group_key: _,
2813                order_key: _,
2814                limit,
2815                offset,
2816                expected_group_size: _,
2817            } => {
2818                if let Some(limit) = limit {
2819                    f(limit)
2820                }
2821                f(offset)
2822            }
2823            Distinct { input: _ }
2824            | Negate { input: _ }
2825            | Threshold { input: _ }
2826            | Union { base: _, inputs: _ } => (),
2827        }
2828    }
2829
2830    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2831    where
2832        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2833        E: From<RecursionLimitError>,
2834    {
2835        use HirRelationExpr::*;
2836        match self {
2837            Constant { rows: _, typ: _ }
2838            | Get { id: _, typ: _ }
2839            | Let {
2840                name: _,
2841                id: _,
2842                value: _,
2843                body: _,
2844            }
2845            | LetRec {
2846                limit: _,
2847                bindings: _,
2848                body: _,
2849            }
2850            | Project {
2851                input: _,
2852                outputs: _,
2853            } => (),
2854            Map { input: _, scalars } => {
2855                for scalar in scalars {
2856                    f(scalar)?;
2857                }
2858            }
2859            CallTable { func: _, exprs } => {
2860                for expr in exprs {
2861                    f(expr)?;
2862                }
2863            }
2864            Filter {
2865                input: _,
2866                predicates,
2867            } => {
2868                for predicate in predicates {
2869                    f(predicate)?;
2870                }
2871            }
2872            Join {
2873                left: _,
2874                right: _,
2875                on,
2876                kind: _,
2877            } => f(on)?,
2878            Reduce {
2879                input: _,
2880                group_key: _,
2881                aggregates,
2882                expected_group_size: _,
2883            } => {
2884                for aggregate in aggregates {
2885                    f(aggregate.expr.as_ref())?;
2886                }
2887            }
2888            TopK {
2889                input: _,
2890                group_key: _,
2891                order_key: _,
2892                limit,
2893                offset,
2894                expected_group_size: _,
2895            } => {
2896                if let Some(limit) = limit {
2897                    f(limit)?
2898                }
2899                f(offset)?
2900            }
2901            Distinct { input: _ }
2902            | Negate { input: _ }
2903            | Threshold { input: _ }
2904            | Union { base: _, inputs: _ } => (),
2905        }
2906        Ok(())
2907    }
2908
2909    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2910    where
2911        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2912        E: From<RecursionLimitError>,
2913    {
2914        use HirRelationExpr::*;
2915        match self {
2916            Constant { rows: _, typ: _ }
2917            | Get { id: _, typ: _ }
2918            | Let {
2919                name: _,
2920                id: _,
2921                value: _,
2922                body: _,
2923            }
2924            | LetRec {
2925                limit: _,
2926                bindings: _,
2927                body: _,
2928            }
2929            | Project {
2930                input: _,
2931                outputs: _,
2932            } => (),
2933            Map { input: _, scalars } => {
2934                for scalar in scalars {
2935                    f(scalar)?;
2936                }
2937            }
2938            CallTable { func: _, exprs } => {
2939                for expr in exprs {
2940                    f(expr)?;
2941                }
2942            }
2943            Filter {
2944                input: _,
2945                predicates,
2946            } => {
2947                for predicate in predicates {
2948                    f(predicate)?;
2949                }
2950            }
2951            Join {
2952                left: _,
2953                right: _,
2954                on,
2955                kind: _,
2956            } => f(on)?,
2957            Reduce {
2958                input: _,
2959                group_key: _,
2960                aggregates,
2961                expected_group_size: _,
2962            } => {
2963                for aggregate in aggregates {
2964                    f(aggregate.expr.as_mut())?;
2965                }
2966            }
2967            TopK {
2968                input: _,
2969                group_key: _,
2970                order_key: _,
2971                limit,
2972                offset,
2973                expected_group_size: _,
2974            } => {
2975                if let Some(limit) = limit {
2976                    f(limit)?
2977                }
2978                f(offset)?
2979            }
2980            Distinct { input: _ }
2981            | Negate { input: _ }
2982            | Threshold { input: _ }
2983            | Union { base: _, inputs: _ } => (),
2984        }
2985        Ok(())
2986    }
2987}
2988
2989impl HirScalarExpr {
2990    pub fn name(&self) -> Option<Arc<str>> {
2991        use HirScalarExpr::*;
2992        match self {
2993            Column(_, name)
2994            | Parameter(_, name)
2995            | Literal(_, _, name)
2996            | CallUnmaterializable(_, name)
2997            | CallUnary { name, .. }
2998            | CallBinary { name, .. }
2999            | CallVariadic { name, .. }
3000            | If { name, .. }
3001            | Exists(_, name)
3002            | Select(_, name)
3003            | Windowing(_, name) => name.0.clone(),
3004        }
3005    }
3006
3007    /// Replaces any parameter references in the expression with the
3008    /// corresponding datum in `params`.
3009    pub fn bind_parameters(
3010        &mut self,
3011        scx: &StatementContext,
3012        lifetime: QueryLifetime,
3013        params: &Params,
3014    ) -> Result<(), PlanError> {
3015        #[allow(deprecated)]
3016        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3017            if let HirScalarExpr::Parameter(n, name) = e {
3018                let datum = match params.datums.iter().nth(*n - 1) {
3019                    None => return Err(PlanError::UnknownParameter(*n)),
3020                    Some(datum) => datum,
3021                };
3022                let scalar_type = &params.execute_types[*n - 1];
3023                let row = Row::pack([datum]);
3024                let column_type = scalar_type.clone().nullable(datum.is_null());
3025
3026                let name = if let Some(name) = &name.0 {
3027                    Some(Arc::clone(name))
3028                } else {
3029                    Some(Arc::from(format!("${n}")))
3030                };
3031
3032                let qcx = QueryContext::root(scx, lifetime);
3033                let ecx = execute_expr_context(&qcx);
3034
3035                *e = plan_cast(
3036                    &ecx,
3037                    *EXECUTE_CAST_CONTEXT,
3038                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3039                    &params.expected_types[*n - 1],
3040                )
3041                .expect("checked in plan_params");
3042            }
3043            Ok(())
3044        })
3045    }
3046
3047    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3048    /// replaced with the corresponding expression fragment from `params` rather
3049    /// than a datum.
3050    ///
3051    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3052    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3053    /// in `self` that refer to invalid indices of `params` will cause a panic.
3054    ///
3055    /// Column references in parameters will be corrected to account for the
3056    /// depth at which they are spliced.
3057    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3058        #[allow(deprecated)]
3059        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3060                                                        e: &mut HirScalarExpr|
3061         -> Result<(), ()> {
3062            if let HirScalarExpr::Parameter(i, _name) = e {
3063                *e = params[*i - 1].clone();
3064                // Correct any column references in the parameter expression for
3065                // its new depth.
3066                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3067                    if col.level >= d {
3068                        col.level += depth
3069                    }
3070                });
3071            }
3072            Ok(())
3073        });
3074    }
3075
3076    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3077    pub fn contains_temporal(&self) -> bool {
3078        let mut contains = false;
3079        #[allow(deprecated)]
3080        self.visit_post_nolimit(&mut |e| {
3081            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3082                contains = true;
3083            }
3084        });
3085        contains
3086    }
3087
3088    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3089    pub fn contains_unmaterializable(&self) -> bool {
3090        let mut contains = false;
3091        #[allow(deprecated)]
3092        self.visit_post_nolimit(&mut |e| {
3093            if let Self::CallUnmaterializable(_, _) = e {
3094                contains = true;
3095            }
3096        });
3097        contains
3098    }
3099
3100    /// Constructs an unnamed column reference in the current scope.
3101    /// Use [`HirScalarExpr::named_column`] when a name is known.
3102    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3103    pub fn column(index: usize) -> HirScalarExpr {
3104        HirScalarExpr::Column(
3105            ColumnRef {
3106                level: 0,
3107                column: index,
3108            },
3109            TreatAsEqual(None),
3110        )
3111    }
3112
3113    /// Constructs an unnamed column reference.
3114    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3115        HirScalarExpr::Column(cr, TreatAsEqual(None))
3116    }
3117
3118    /// Constructs a named column reference.
3119    /// Names are interned by a `NameManager`.
3120    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3121        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3122    }
3123
3124    pub fn parameter(n: usize) -> HirScalarExpr {
3125        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3126    }
3127
3128    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3129        let col_type = scalar_type.nullable(datum.is_null());
3130        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3131        let row = Row::pack([datum]);
3132        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3133    }
3134
3135    pub fn literal_true() -> HirScalarExpr {
3136        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3137    }
3138
3139    pub fn literal_false() -> HirScalarExpr {
3140        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3141    }
3142
3143    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3144        HirScalarExpr::literal(Datum::Null, scalar_type)
3145    }
3146
3147    pub fn literal_1d_array(
3148        datums: Vec<Datum>,
3149        element_scalar_type: SqlScalarType,
3150    ) -> Result<HirScalarExpr, PlanError> {
3151        let scalar_type = match element_scalar_type {
3152            SqlScalarType::Array(_) => {
3153                sql_bail!("cannot build array from array type");
3154            }
3155            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3156        };
3157
3158        let mut row = Row::default();
3159        row.packer()
3160            .try_push_array(
3161                &[ArrayDimension {
3162                    lower_bound: 1,
3163                    length: datums.len(),
3164                }],
3165                datums,
3166            )
3167            .expect("array constructed to be valid");
3168
3169        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3170    }
3171
3172    pub fn as_literal(&self) -> Option<Datum<'_>> {
3173        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3174            Some(row.unpack_first())
3175        } else {
3176            None
3177        }
3178    }
3179
3180    pub fn is_literal_true(&self) -> bool {
3181        Some(Datum::True) == self.as_literal()
3182    }
3183
3184    pub fn is_literal_false(&self) -> bool {
3185        Some(Datum::False) == self.as_literal()
3186    }
3187
3188    pub fn is_literal_null(&self) -> bool {
3189        Some(Datum::Null) == self.as_literal()
3190    }
3191
3192    /// Return true iff `self` consists only of literals, materializable function calls, and
3193    /// if-else statements.
3194    pub fn is_constant(&self) -> bool {
3195        let mut worklist = vec![self];
3196        while let Some(expr) = worklist.pop() {
3197            match expr {
3198                Self::Literal(..) => {
3199                    // leaf node, do nothing
3200                }
3201                Self::CallUnary { expr, .. } => {
3202                    worklist.push(expr);
3203                }
3204                Self::CallBinary {
3205                    func: _,
3206                    expr1,
3207                    expr2,
3208                    name: _,
3209                } => {
3210                    worklist.push(expr1);
3211                    worklist.push(expr2);
3212                }
3213                Self::CallVariadic {
3214                    func: _,
3215                    exprs,
3216                    name: _,
3217                } => {
3218                    worklist.extend(exprs.iter());
3219                }
3220                // (CallUnmaterializable is not allowed)
3221                Self::If {
3222                    cond,
3223                    then,
3224                    els,
3225                    name: _,
3226                } => {
3227                    worklist.push(cond);
3228                    worklist.push(then);
3229                    worklist.push(els);
3230                }
3231                _ => {
3232                    return false; // Any other node makes `self` non-constant.
3233                }
3234            }
3235        }
3236        true
3237    }
3238
3239    pub fn call_unary(self, func: UnaryFunc) -> Self {
3240        HirScalarExpr::CallUnary {
3241            func,
3242            expr: Box::new(self),
3243            name: NameMetadata::default(),
3244        }
3245    }
3246
3247    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3248        HirScalarExpr::CallBinary {
3249            func: func.into(),
3250            expr1: Box::new(self),
3251            expr2: Box::new(other),
3252            name: NameMetadata::default(),
3253        }
3254    }
3255
3256    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3257        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3258    }
3259
3260    pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3261        HirScalarExpr::CallVariadic {
3262            func,
3263            exprs,
3264            name: NameMetadata::default(),
3265        }
3266    }
3267
3268    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3269        HirScalarExpr::If {
3270            cond: Box::new(cond),
3271            then: Box::new(then),
3272            els: Box::new(els),
3273            name: NameMetadata::default(),
3274        }
3275    }
3276
3277    pub fn windowing(expr: WindowExpr) -> Self {
3278        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3279    }
3280
3281    pub fn or(self, other: Self) -> Self {
3282        HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3283    }
3284
3285    pub fn and(self, other: Self) -> Self {
3286        HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3287    }
3288
3289    pub fn not(self) -> Self {
3290        self.call_unary(UnaryFunc::Not(func::Not))
3291    }
3292
3293    pub fn call_is_null(self) -> Self {
3294        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3295    }
3296
3297    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3298    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3299        match args.len() {
3300            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3301            1 => args.swap_remove(0),
3302            _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3303        }
3304    }
3305
3306    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3307    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3308        match args.len() {
3309            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3310            1 => args.swap_remove(0),
3311            _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3312        }
3313    }
3314
3315    pub fn take(&mut self) -> Self {
3316        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3317    }
3318
3319    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3320    /// Visits the column references in this scalar expression.
3321    ///
3322    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3323    /// which will be incremented with each subquery entered and presented to the supplied
3324    /// function `f`.
3325    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3326    where
3327        F: FnMut(usize, &ColumnRef),
3328    {
3329        #[allow(deprecated)]
3330        let _ = self.visit_recursively(depth, &mut |depth: usize,
3331                                                    e: &HirScalarExpr|
3332         -> Result<(), ()> {
3333            if let HirScalarExpr::Column(col, _name) = e {
3334                f(depth, col)
3335            }
3336            Ok(())
3337        });
3338    }
3339
3340    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3341    /// Like `visit_columns`, but permits mutating the column references.
3342    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3343    where
3344        F: FnMut(usize, &mut ColumnRef),
3345    {
3346        #[allow(deprecated)]
3347        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3348                                                        e: &mut HirScalarExpr|
3349         -> Result<(), ()> {
3350            if let HirScalarExpr::Column(col, _name) = e {
3351                f(depth, col)
3352            }
3353            Ok(())
3354        });
3355    }
3356
3357    /// Visits those column references in this scalar expression that refer to the root
3358    /// level. These include column references that are at the root level, as well as column
3359    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3360    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3361    /// "root level" to be `self`'s level.)
3362    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3363    where
3364        F: FnMut(usize),
3365    {
3366        #[allow(deprecated)]
3367        let _ = self.visit_recursively(0, &mut |depth: usize,
3368                                                e: &HirScalarExpr|
3369         -> Result<(), ()> {
3370            if let HirScalarExpr::Column(col, _name) = e {
3371                if col.level == depth {
3372                    f(col.column)
3373                }
3374            }
3375            Ok(())
3376        });
3377    }
3378
3379    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3380    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3381    where
3382        F: FnMut(&mut usize),
3383    {
3384        #[allow(deprecated)]
3385        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3386                                                    e: &mut HirScalarExpr|
3387         -> Result<(), ()> {
3388            if let HirScalarExpr::Column(col, _name) = e {
3389                if col.level == depth {
3390                    f(&mut col.column)
3391                }
3392            }
3393            Ok(())
3394        });
3395    }
3396
3397    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3398    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3399    /// in them. It takes the current depth of the expression and increases it when
3400    /// entering a subquery.
3401    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3402    where
3403        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3404    {
3405        match self {
3406            HirScalarExpr::Literal(..)
3407            | HirScalarExpr::Parameter(..)
3408            | HirScalarExpr::CallUnmaterializable(..)
3409            | HirScalarExpr::Column(..) => (),
3410            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3411            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3412                expr1.visit_recursively(depth, f)?;
3413                expr2.visit_recursively(depth, f)?;
3414            }
3415            HirScalarExpr::CallVariadic { exprs, .. } => {
3416                for expr in exprs {
3417                    expr.visit_recursively(depth, f)?;
3418                }
3419            }
3420            HirScalarExpr::If {
3421                cond,
3422                then,
3423                els,
3424                name: _,
3425            } => {
3426                cond.visit_recursively(depth, f)?;
3427                then.visit_recursively(depth, f)?;
3428                els.visit_recursively(depth, f)?;
3429            }
3430            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3431                #[allow(deprecated)]
3432                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3433                    e.visit_recursively(depth, f)
3434                })?;
3435            }
3436            HirScalarExpr::Windowing(expr, _name) => {
3437                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3438            }
3439        }
3440        f(depth, self)
3441    }
3442
3443    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3444    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3445    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3446    where
3447        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3448    {
3449        match self {
3450            HirScalarExpr::Literal(..)
3451            | HirScalarExpr::Parameter(..)
3452            | HirScalarExpr::CallUnmaterializable(..)
3453            | HirScalarExpr::Column(..) => (),
3454            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3455            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3456                expr1.visit_recursively_mut(depth, f)?;
3457                expr2.visit_recursively_mut(depth, f)?;
3458            }
3459            HirScalarExpr::CallVariadic { exprs, .. } => {
3460                for expr in exprs {
3461                    expr.visit_recursively_mut(depth, f)?;
3462                }
3463            }
3464            HirScalarExpr::If {
3465                cond,
3466                then,
3467                els,
3468                name: _,
3469            } => {
3470                cond.visit_recursively_mut(depth, f)?;
3471                then.visit_recursively_mut(depth, f)?;
3472                els.visit_recursively_mut(depth, f)?;
3473            }
3474            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3475                #[allow(deprecated)]
3476                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3477                    e.visit_recursively_mut(depth, f)
3478                })?;
3479            }
3480            HirScalarExpr::Windowing(expr, _name) => {
3481                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3482            }
3483        }
3484        f(depth, self)
3485    }
3486
3487    /// Attempts to simplify self into a literal.
3488    ///
3489    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3490    /// an evaluation error occurs during simplification, or if self contains
3491    /// - a subquery
3492    /// - a column reference to an outer level
3493    /// - a parameter
3494    /// - a window function call
3495    fn simplify_to_literal(self) -> Option<Row> {
3496        let mut expr = self
3497            .lower_uncorrelated(crate::plan::lowering::Config::default())
3498            .ok()?;
3499        expr.reduce(&[]);
3500        match expr {
3501            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3502            _ => None,
3503        }
3504    }
3505
3506    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3507    /// or an evaluation error occurs during simplification), it returns
3508    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3509    ///
3510    /// The returned error is an _internal_ error if the expression contains
3511    /// - a subquery
3512    /// - a column reference to an outer level
3513    /// - a parameter
3514    /// - a window function call
3515    ///
3516    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3517    /// msg.
3518    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3519        let mut expr = self
3520            .lower_uncorrelated(crate::plan::lowering::Config::default())
3521            .map_err(|err| {
3522                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3523            })?;
3524        expr.reduce(&[]);
3525        match expr {
3526            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3527            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3528                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3529            ),
3530            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3531                "Not a constant".to_string(),
3532            )),
3533        }
3534    }
3535
3536    /// Attempts to simplify this expression to a literal 64-bit integer.
3537    ///
3538    /// Returns `None` if this expression cannot be simplified, e.g. because it
3539    /// contains non-literal values.
3540    ///
3541    /// # Panics
3542    ///
3543    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3544    pub fn into_literal_int64(self) -> Option<i64> {
3545        self.simplify_to_literal().and_then(|row| {
3546            let datum = row.unpack_first();
3547            if datum.is_null() {
3548                None
3549            } else {
3550                Some(datum.unwrap_int64())
3551            }
3552        })
3553    }
3554
3555    /// Attempts to simplify this expression to a literal string.
3556    ///
3557    /// Returns `None` if this expression cannot be simplified, e.g. because it
3558    /// contains non-literal values.
3559    ///
3560    /// # Panics
3561    ///
3562    /// Panics if this expression does not have type [`SqlScalarType::String`].
3563    pub fn into_literal_string(self) -> Option<String> {
3564        self.simplify_to_literal().and_then(|row| {
3565            let datum = row.unpack_first();
3566            if datum.is_null() {
3567                None
3568            } else {
3569                Some(datum.unwrap_str().to_owned())
3570            }
3571        })
3572    }
3573
3574    /// Attempts to simplify this expression to a literal MzTimestamp.
3575    ///
3576    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3577    /// simplified, e.g. because it contains non-literal values or a cast fails.
3578    ///
3579    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3580    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3581    /// See `try_into_literal_int64` as an example.
3582    ///
3583    /// # Panics
3584    ///
3585    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3586    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3587        self.simplify_to_literal().and_then(|row| {
3588            let datum = row.unpack_first();
3589            if datum.is_null() {
3590                None
3591            } else {
3592                Some(datum.unwrap_mz_timestamp())
3593            }
3594        })
3595    }
3596
3597    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3598    /// returns it as an i64.
3599    ///
3600    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3601    /// - it's not a constant expression (as determined by `is_constant`)
3602    /// - evaluates to null
3603    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3604    ///
3605    /// # Panics
3606    ///
3607    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3608    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3609        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3610        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3611        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3612        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3613        // preconditions also of all the other into_literal_... functions.)
3614        if !self.is_constant() {
3615            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3616                "Expected a constant expression, got {}",
3617                self
3618            )));
3619        }
3620        self.clone()
3621            .simplify_to_literal_with_result()
3622            .and_then(|row| {
3623                let datum = row.unpack_first();
3624                if datum.is_null() {
3625                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3626                        "Expected an expression that evaluates to a non-null value, got {}",
3627                        self
3628                    )))
3629                } else {
3630                    Ok(datum.unwrap_int64())
3631                }
3632            })
3633    }
3634
3635    pub fn contains_parameters(&self) -> bool {
3636        let mut contains_parameters = false;
3637        #[allow(deprecated)]
3638        let _ = self.visit_recursively(0, &mut |_depth: usize,
3639                                                expr: &HirScalarExpr|
3640         -> Result<(), ()> {
3641            if let HirScalarExpr::Parameter(..) = expr {
3642                contains_parameters = true;
3643            }
3644            Ok(())
3645        });
3646        contains_parameters
3647    }
3648}
3649
3650impl VisitChildren<Self> for HirScalarExpr {
3651    fn visit_children<F>(&self, mut f: F)
3652    where
3653        F: FnMut(&Self),
3654    {
3655        use HirScalarExpr::*;
3656        match self {
3657            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3658            CallUnary { expr, .. } => f(expr),
3659            CallBinary { expr1, expr2, .. } => {
3660                f(expr1);
3661                f(expr2);
3662            }
3663            CallVariadic { exprs, .. } => {
3664                for expr in exprs {
3665                    f(expr);
3666                }
3667            }
3668            If {
3669                cond,
3670                then,
3671                els,
3672                name: _,
3673            } => {
3674                f(cond);
3675                f(then);
3676                f(els);
3677            }
3678            Exists(..) | Select(..) => (),
3679            Windowing(expr, _name) => expr.visit_children(f),
3680        }
3681    }
3682
3683    fn visit_mut_children<F>(&mut self, mut f: F)
3684    where
3685        F: FnMut(&mut Self),
3686    {
3687        use HirScalarExpr::*;
3688        match self {
3689            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3690            CallUnary { expr, .. } => f(expr),
3691            CallBinary { expr1, expr2, .. } => {
3692                f(expr1);
3693                f(expr2);
3694            }
3695            CallVariadic { exprs, .. } => {
3696                for expr in exprs {
3697                    f(expr);
3698                }
3699            }
3700            If {
3701                cond,
3702                then,
3703                els,
3704                name: _,
3705            } => {
3706                f(cond);
3707                f(then);
3708                f(els);
3709            }
3710            Exists(..) | Select(..) => (),
3711            Windowing(expr, _name) => expr.visit_mut_children(f),
3712        }
3713    }
3714
3715    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3716    where
3717        F: FnMut(&Self) -> Result<(), E>,
3718        E: From<RecursionLimitError>,
3719    {
3720        use HirScalarExpr::*;
3721        match self {
3722            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3723            CallUnary { expr, .. } => f(expr)?,
3724            CallBinary { expr1, expr2, .. } => {
3725                f(expr1)?;
3726                f(expr2)?;
3727            }
3728            CallVariadic { exprs, .. } => {
3729                for expr in exprs {
3730                    f(expr)?;
3731                }
3732            }
3733            If {
3734                cond,
3735                then,
3736                els,
3737                name: _,
3738            } => {
3739                f(cond)?;
3740                f(then)?;
3741                f(els)?;
3742            }
3743            Exists(..) | Select(..) => (),
3744            Windowing(expr, _name) => expr.try_visit_children(f)?,
3745        }
3746        Ok(())
3747    }
3748
3749    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3750    where
3751        F: FnMut(&mut Self) -> Result<(), E>,
3752        E: From<RecursionLimitError>,
3753    {
3754        use HirScalarExpr::*;
3755        match self {
3756            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3757            CallUnary { expr, .. } => f(expr)?,
3758            CallBinary { expr1, expr2, .. } => {
3759                f(expr1)?;
3760                f(expr2)?;
3761            }
3762            CallVariadic { exprs, .. } => {
3763                for expr in exprs {
3764                    f(expr)?;
3765                }
3766            }
3767            If {
3768                cond,
3769                then,
3770                els,
3771                name: _,
3772            } => {
3773                f(cond)?;
3774                f(then)?;
3775                f(els)?;
3776            }
3777            Exists(..) | Select(..) => (),
3778            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3779        }
3780        Ok(())
3781    }
3782}
3783
3784impl AbstractExpr for HirScalarExpr {
3785    type Type = SqlColumnType;
3786
3787    fn typ(
3788        &self,
3789        outers: &[SqlRelationType],
3790        inner: &SqlRelationType,
3791        params: &BTreeMap<usize, SqlScalarType>,
3792    ) -> Self::Type {
3793        stack::maybe_grow(|| match self {
3794            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3795                if *level == 0 {
3796                    inner.column_types[*column].clone()
3797                } else {
3798                    outers[*level - 1].column_types[*column].clone()
3799                }
3800            }
3801            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3802            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3803            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3804            HirScalarExpr::CallUnary {
3805                expr,
3806                func,
3807                name: _,
3808            } => func.output_type(expr.typ(outers, inner, params)),
3809            HirScalarExpr::CallBinary {
3810                expr1,
3811                expr2,
3812                func,
3813                name: _,
3814            } => func.output_type(
3815                expr1.typ(outers, inner, params),
3816                expr2.typ(outers, inner, params),
3817            ),
3818            HirScalarExpr::CallVariadic {
3819                exprs,
3820                func,
3821                name: _,
3822            } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3823            HirScalarExpr::If {
3824                cond: _,
3825                then,
3826                els,
3827                name: _,
3828            } => {
3829                let then_type = then.typ(outers, inner, params);
3830                let else_type = els.typ(outers, inner, params);
3831                then_type.union(&else_type).unwrap()
3832            }
3833            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3834            HirScalarExpr::Select(expr, _name) => {
3835                let mut outers = outers.to_vec();
3836                outers.insert(0, inner.clone());
3837                expr.typ(&outers, params)
3838                    .column_types
3839                    .into_element()
3840                    .nullable(true)
3841            }
3842            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3843        })
3844    }
3845}
3846
3847impl AggregateExpr {
3848    pub fn typ(
3849        &self,
3850        outers: &[SqlRelationType],
3851        inner: &SqlRelationType,
3852        params: &BTreeMap<usize, SqlScalarType>,
3853    ) -> SqlColumnType {
3854        self.func.output_type(self.expr.typ(outers, inner, params))
3855    }
3856
3857    /// Returns whether the expression is COUNT(*) or not.  Note that
3858    /// when we define the count builtin in sql::func, we convert
3859    /// COUNT(*) to COUNT(true), making it indistinguishable from
3860    /// literal COUNT(true), but we prefer to consider this as the
3861    /// former.
3862    ///
3863    /// (MIR has the same `is_count_asterisk`.)
3864    pub fn is_count_asterisk(&self) -> bool {
3865        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3866    }
3867}