mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Type-level tag for the HIR intermediate representation.
///
/// Implementing [`IR`] for this type names [`HirRelationExpr`] and [`HirScalarExpr`] as HIR's
/// relation and scalar expression types, which lets generic helpers from
/// `mz_expr::virtual_syntax` (e.g. `AlgExcept`) be implemented for HIR.
// This data-less tag has no `Debug` impl; the lint is silenced explicitly.
#[allow(missing_debug_implementations)]
pub struct Hir;

impl IR for Hir {
    /// HIR relation expressions.
    type Relation = HirRelationExpr;
    /// HIR scalar expressions.
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56        if *all {
57            let rhs = rhs.negate();
58            HirRelationExpr::union(lhs, rhs).threshold()
59        } else {
60            let lhs = lhs.distinct();
61            let rhs = rhs.distinct().negate();
62            HirRelationExpr::union(lhs, rhs).threshold()
63        }
64    }
65
66    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67        let mut result = None;
68
69        use HirRelationExpr::*;
70        if let Threshold { input } = expr {
71            if let Union { base: lhs, inputs } = input.as_ref() {
72                if let [rhs] = &inputs[..] {
73                    if let Negate { input: rhs } = rhs {
74                        match (lhs.as_ref(), rhs.as_ref()) {
75                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
76                                let all = false;
77                                let lhs = lhs.as_ref();
78                                let rhs = rhs.as_ref();
79                                result = Some(Except { all, lhs, rhs })
80                            }
81                            (lhs, rhs) => {
82                                let all = true;
83                                result = Some(Except { all, lhs, rhs })
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        result
92    }
93}
94
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
///
/// Note: the derived `PartialOrd`/`Ord` compare variants in declaration order, so reordering
/// variants or fields changes observable behavior.
pub enum HirRelationExpr {
    /// A literal collection consisting of `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id` (e.g. one bound by `Let`/`LetRec`),
    /// whose relation type is `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        /// Each entry presumably is `(name, id, value, type of value)` — TODO confirm.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// As in MIR: keeps the columns of `input` listed (by index) in `outputs`.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// As in MIR: appends one column per expression in `scalars` to each row of `input`.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Calls the table function `func` with argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// As in MIR: keeps the rows of `input` that satisfy all of `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        /// The join condition.
        on: HirScalarExpr,
        /// The kind of join (inner, or one of the outer joins noted above).
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        /// Column indices of `input` used to form groups.
        group_key: Vec<usize>,
        /// Aggregates computed for each group.
        aggregates: Vec<AggregateExpr>,
        /// Optional user-supplied hint of how many rows share a group key
        /// (cf. `TopK::expected_group_size`).
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicity of every row of `input` (used, e.g., to build `EXCEPT`).
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` and every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
198
/// Stored column metadata.
///
/// An optional name (it appears as the `name` component of every [`HirScalarExpr`] variant).
/// It is wrapped in [`TreatAsEqual`], presumably so that it does not influence equality,
/// ordering or hashing of the expressions carrying it — TODO confirm against `TreatAsEqual`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
201
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant carries a [`NameMetadata`] (as its last component or its `name` field).
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via e.g. Exists. This means
    /// that a variable could refer to a column of the current input, or to a column of an outer
    /// relation. We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A reference to a query parameter (e.g. `$1`).
    Parameter(usize, NameMetadata),
    /// A literal value stored in a [`Row`], together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an [`UnmaterializableFunc`].
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// As in MIR: evaluates to `then` when `cond` is true, and to `els` otherwise.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
243
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function, including its arguments and its `ColumnOrder`-based ORDER BY.
    pub func: WindowExprType,
    /// The PARTITION BY expressions (empty when there is no partitioning).
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s points into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
262
263impl WindowExpr {
264    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265    where
266        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267    {
268        #[allow(deprecated)]
269        self.func.visit_expressions(f)?;
270        for expr in self.partition_by.iter() {
271            f(expr)?;
272        }
273        for expr in self.order_by.iter() {
274            f(expr)?;
275        }
276        Ok(())
277    }
278
279    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280    where
281        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282    {
283        #[allow(deprecated)]
284        self.func.visit_expressions_mut(f)?;
285        for expr in self.partition_by.iter_mut() {
286            f(expr)?;
287        }
288        for expr in self.order_by.iter_mut() {
289            f(expr)?;
290        }
291        Ok(())
292    }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296    fn visit_children<F>(&self, mut f: F)
297    where
298        F: FnMut(&HirScalarExpr),
299    {
300        self.func.visit_children(&mut f);
301        for expr in self.partition_by.iter() {
302            f(expr);
303        }
304        for expr in self.order_by.iter() {
305            f(expr);
306        }
307    }
308
309    fn visit_mut_children<F>(&mut self, mut f: F)
310    where
311        F: FnMut(&mut HirScalarExpr),
312    {
313        self.func.visit_mut_children(&mut f);
314        for expr in self.partition_by.iter_mut() {
315            f(expr);
316        }
317        for expr in self.order_by.iter_mut() {
318            f(expr);
319        }
320    }
321
322    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323    where
324        F: FnMut(&HirScalarExpr) -> Result<(), E>,
325        E: From<RecursionLimitError>,
326    {
327        self.func.try_visit_children(&mut f)?;
328        for expr in self.partition_by.iter() {
329            f(expr)?;
330        }
331        for expr in self.order_by.iter() {
332            f(expr)?;
333        }
334        Ok(())
335    }
336
337    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338    where
339        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340        E: From<RecursionLimitError>,
341    {
342        self.func.try_visit_mut_children(&mut f)?;
343        for expr in self.partition_by.iter_mut() {
344            f(expr)?;
345        }
346        for expr in self.order_by.iter_mut() {
347            f(expr)?;
348        }
349        Ok(())
350    }
351}
352
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function (`row_number`, `rank`, `dense_rank`).
    Scalar(ScalarWindowExpr),
    /// A value window function (e.g. `lag`, `first_value`).
    Value(ValueWindowExpr),
    /// An aggregate evaluated as a window function.
    Aggregate(AggregateWindowExpr),
}
375
376impl WindowExprType {
377    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379    where
380        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381    {
382        #[allow(deprecated)]
383        match self {
384            Self::Scalar(expr) => expr.visit_expressions(f),
385            Self::Value(expr) => expr.visit_expressions(f),
386            Self::Aggregate(expr) => expr.visit_expressions(f),
387        }
388    }
389
390    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392    where
393        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394    {
395        #[allow(deprecated)]
396        match self {
397            Self::Scalar(expr) => expr.visit_expressions_mut(f),
398            Self::Value(expr) => expr.visit_expressions_mut(f),
399            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400        }
401    }
402
403    fn typ(
404        &self,
405        outers: &[SqlRelationType],
406        inner: &SqlRelationType,
407        params: &BTreeMap<usize, SqlScalarType>,
408    ) -> SqlColumnType {
409        match self {
410            Self::Scalar(expr) => expr.typ(outers, inner, params),
411            Self::Value(expr) => expr.typ(outers, inner, params),
412            Self::Aggregate(expr) => expr.typ(outers, inner, params),
413        }
414    }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418    fn visit_children<F>(&self, f: F)
419    where
420        F: FnMut(&HirScalarExpr),
421    {
422        match self {
423            Self::Scalar(_) => (),
424            Self::Value(expr) => expr.visit_children(f),
425            Self::Aggregate(expr) => expr.visit_children(f),
426        }
427    }
428
429    fn visit_mut_children<F>(&mut self, f: F)
430    where
431        F: FnMut(&mut HirScalarExpr),
432    {
433        match self {
434            Self::Scalar(_) => (),
435            Self::Value(expr) => expr.visit_mut_children(f),
436            Self::Aggregate(expr) => expr.visit_mut_children(f),
437        }
438    }
439
440    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441    where
442        F: FnMut(&HirScalarExpr) -> Result<(), E>,
443        E: From<RecursionLimitError>,
444    {
445        match self {
446            Self::Scalar(_) => Ok(()),
447            Self::Value(expr) => expr.try_visit_children(f),
448            Self::Aggregate(expr) => expr.try_visit_children(f),
449        }
450    }
451
452    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453    where
454        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455        E: From<RecursionLimitError>,
456    {
457        match self {
458            Self::Scalar(_) => Ok(()),
459            Self::Value(expr) => expr.try_visit_mut_children(f),
460            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461        }
462    }
463}
464
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A scalar window function invocation. It takes no argument expressions.
pub struct ScalarWindowExpr {
    /// Which scalar window function to evaluate.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
470
471impl ScalarWindowExpr {
472    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474    where
475        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476    {
477        match self.func {
478            ScalarWindowFunc::RowNumber => {}
479            ScalarWindowFunc::Rank => {}
480            ScalarWindowFunc::DenseRank => {}
481        }
482        Ok(())
483    }
484
485    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487    where
488        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489    {
490        match self.func {
491            ScalarWindowFunc::RowNumber => {}
492            ScalarWindowFunc::Rank => {}
493            ScalarWindowFunc::DenseRank => {}
494        }
495        Ok(())
496    }
497
498    fn typ(
499        &self,
500        _outers: &[SqlRelationType],
501        _inner: &SqlRelationType,
502        _params: &BTreeMap<usize, SqlScalarType>,
503    ) -> SqlColumnType {
504        self.func.output_type()
505    }
506
507    pub fn into_expr(self) -> mz_expr::AggregateFunc {
508        match self.func {
509            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510                order_by: self.order_by,
511            },
512            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513                order_by: self.order_by,
514            },
515            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516                order_by: self.order_by,
517            },
518        }
519    }
520}
521
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// `row_number()`
    RowNumber,
    /// `rank()`
    Rank,
    /// `dense_rank()`
    DenseRank,
}
529
530impl Display for ScalarWindowFunc {
531    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532        match self {
533            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534            ScalarWindowFunc::Rank => write!(f, "rank"),
535            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536        }
537    }
538}
539
540impl ScalarWindowFunc {
541    pub fn output_type(&self) -> SqlColumnType {
542        match self {
543            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
544            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
545            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
546        }
547    }
548}
549
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A value window function invocation (e.g. `lag`, `first_value`) with its arguments and
/// window specification.
pub struct ValueWindowExpr {
    /// Which value window function to evaluate.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. `ValueWindowFunc::into_expr` passes it to `first_value`/`last_value`;
    /// `lag`/`lead` do not use it.
    pub window_frame: WindowFrame,
    /// Whether nulls are ignored. `ValueWindowFunc::into_expr` passes it to `lag`/`lead`;
    /// `first_value`/`last_value` do not use it.
    pub ignore_nulls: bool,
}
564
565impl Display for ValueWindowFunc {
566    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567        match self {
568            ValueWindowFunc::Lag => write!(f, "lag"),
569            ValueWindowFunc::Lead => write!(f, "lead"),
570            ValueWindowFunc::FirstValue => write!(f, "first_value"),
571            ValueWindowFunc::LastValue => write!(f, "last_value"),
572            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573        }
574    }
575}
576
577impl ValueWindowExpr {
578    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580    where
581        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582    {
583        f(&self.args)
584    }
585
586    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588    where
589        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590    {
591        f(&mut self.args)
592    }
593
594    fn typ(
595        &self,
596        outers: &[SqlRelationType],
597        inner: &SqlRelationType,
598        params: &BTreeMap<usize, SqlScalarType>,
599    ) -> SqlColumnType {
600        self.func.output_type(self.args.typ(outers, inner, params))
601    }
602
603    /// Converts into `mz_expr::AggregateFunc`.
604    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605        (
606            self.args,
607            self.func
608                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609        )
610    }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614    fn visit_children<F>(&self, mut f: F)
615    where
616        F: FnMut(&HirScalarExpr),
617    {
618        f(&self.args)
619    }
620
621    fn visit_mut_children<F>(&mut self, mut f: F)
622    where
623        F: FnMut(&mut HirScalarExpr),
624    {
625        f(&mut self.args)
626    }
627
628    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629    where
630        F: FnMut(&HirScalarExpr) -> Result<(), E>,
631        E: From<RecursionLimitError>,
632    {
633        f(&self.args)
634    }
635
636    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637    where
638        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639        E: From<RecursionLimitError>,
640    {
641        f(&mut self.args)
642    }
643}
644
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// `lag`
    Lag,
    /// `lead`
    Lead,
    /// `first_value`
    FirstValue,
    /// `last_value`
    LastValue,
    /// Several value window functions evaluated together over the same window. Their
    /// arguments are encoded as an outer record (see `ValueWindowExpr::args`).
    Fused(Vec<ValueWindowFunc>),
}
654
655impl ValueWindowFunc {
656    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
657        match self {
658            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659                // The input is a (value, offset, default) record, so extract the type of the first arg
660                input_type.scalar_type.unwrap_record_element_type()[0]
661                    .clone()
662                    .nullable(true)
663            }
664            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665                input_type.scalar_type.nullable(true)
666            }
667            ValueWindowFunc::Fused(funcs) => {
668                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669                SqlScalarType::Record {
670                    fields: funcs
671                        .iter()
672                        .zip_eq(input_types)
673                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674                        .collect(),
675                    custom_id: None,
676                }
677                .nullable(false)
678            }
679        }
680    }
681
682    pub fn into_expr(
683        self,
684        order_by: Vec<ColumnOrder>,
685        window_frame: WindowFrame,
686        ignore_nulls: bool,
687    ) -> mz_expr::AggregateFunc {
688        match self {
689            // Lag and Lead are fundamentally the same function, just with opposite directions
690            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691                order_by,
692                lag_lead: mz_expr::LagLeadType::Lag,
693                ignore_nulls,
694            },
695            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696                order_by,
697                lag_lead: mz_expr::LagLeadType::Lead,
698                ignore_nulls,
699            },
700            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701                order_by,
702                window_frame,
703            },
704            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705                order_by,
706                window_frame,
707            },
708            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709                funcs: funcs
710                    .into_iter()
711                    .map(|func| {
712                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713                    })
714                    .collect(),
715                order_by,
716            },
717        }
718    }
719}
720
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A traditional aggregate evaluated as a window function (e.g. `sum(x) OVER (...)`).
pub struct AggregateWindowExpr {
    /// The aggregate function and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame over which the aggregate is computed.
    pub window_frame: WindowFrame,
}
727
728impl AggregateWindowExpr {
729    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731    where
732        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733    {
734        f(&self.aggregate_expr.expr)
735    }
736
737    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739    where
740        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741    {
742        f(&mut self.aggregate_expr.expr)
743    }
744
745    fn typ(
746        &self,
747        outers: &[SqlRelationType],
748        inner: &SqlRelationType,
749        params: &BTreeMap<usize, SqlScalarType>,
750    ) -> SqlColumnType {
751        self.aggregate_expr
752            .func
753            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754    }
755
756    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758            (
759                self.aggregate_expr.expr,
760                FusedWindowAggregate {
761                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762                    order_by: self.order_by,
763                    window_frame: self.window_frame,
764                },
765            )
766        } else {
767            (
768                self.aggregate_expr.expr,
769                WindowAggregate {
770                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771                    order_by: self.order_by,
772                    window_frame: self.window_frame,
773                },
774            )
775        }
776    }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780    fn visit_children<F>(&self, mut f: F)
781    where
782        F: FnMut(&HirScalarExpr),
783    {
784        f(&self.aggregate_expr.expr)
785    }
786
787    fn visit_mut_children<F>(&mut self, mut f: F)
788    where
789        F: FnMut(&mut HirScalarExpr),
790    {
791        f(&mut self.aggregate_expr.expr)
792    }
793
794    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795    where
796        F: FnMut(&HirScalarExpr) -> Result<(), E>,
797        E: From<RecursionLimitError>,
798    {
799        f(&self.aggregate_expr.expr)
800    }
801
802    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803    where
804        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805        E: From<RecursionLimitError>,
806    {
807        f(&mut self.aggregate_expr.expr)
808    }
809}
810
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression whose type is already fully determined.
    Coerced(HirScalarExpr),
    /// A query parameter (e.g. `$1`) whose type is not yet constrained.
    Parameter(usize),
    /// The literal `NULL`.
    LiteralNull,
    /// A string literal (e.g. `'42'`) that may still be coerced to another type.
    LiteralString(String),
    /// A record literal whose fields are themselves coercible.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
843
844impl CoercibleScalarExpr {
845    pub fn type_as(
846        self,
847        ecx: &ExprContext,
848        ty: &SqlScalarType,
849    ) -> Result<HirScalarExpr, PlanError> {
850        let expr = typeconv::plan_coerce(ecx, self, ty)?;
851        let expr_ty = ecx.scalar_type(&expr);
852        if ty != &expr_ty {
853            sql_bail!(
854                "{} must have type {}, not type {}",
855                ecx.name,
856                ecx.humanize_scalar_type(ty, false),
857                ecx.humanize_scalar_type(&expr_ty, false),
858            );
859        }
860        Ok(expr)
861    }
862
863    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
864        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
865    }
866
867    pub fn cast_to(
868        self,
869        ecx: &ExprContext,
870        ccx: CastContext,
871        ty: &SqlScalarType,
872    ) -> Result<HirScalarExpr, PlanError> {
873        let expr = typeconv::plan_coerce(ecx, self, ty)?;
874        typeconv::plan_cast(ecx, ccx, expr, ty)
875    }
876}
877
/// The column type for a [`CoercibleScalarExpr`].
///
/// Unlike [`SqlColumnType`], this can also describe expressions whose type has
/// not been determined yet.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The column type of an already-coerced expression, nullability included.
    Coerced(SqlColumnType),
    /// The type of a literal record: one (possibly uncoerced) column type per
    /// field.
    Record(Vec<CoercibleColumnType>),
    /// The type of a parameter, `NULL` literal, or string literal that has not
    /// been coerced yet.
    Uncoerced,
}
885
886impl CoercibleColumnType {
887    /// Reports the nullability of the type.
888    pub fn nullable(&self) -> bool {
889        match self {
890            // A coerced value's nullability is known.
891            CoercibleColumnType::Coerced(ct) => ct.nullable,
892
893            // A literal record can never be null.
894            CoercibleColumnType::Record(_) => false,
895
896            // An uncoerced literal may be the literal `NULL`, so we have
897            // to conservatively assume it is nullable.
898            CoercibleColumnType::Uncoerced => true,
899        }
900    }
901}
902
/// The scalar type for a [`CoercibleScalarExpr`].
///
/// This is [`CoercibleColumnType`] without the top-level nullability.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A concrete scalar type.
    Coerced(SqlScalarType),
    /// The type of a literal record: one (possibly uncoerced) column type per
    /// field.
    Record(Vec<CoercibleColumnType>),
    /// A type that has not been determined yet.
    Uncoerced,
}
910
911impl CoercibleScalarType {
912    /// Reports whether the scalar type has been coerced.
913    pub fn is_coerced(&self) -> bool {
914        matches!(self, CoercibleScalarType::Coerced(_))
915    }
916
917    /// Returns the coerced scalar type, if the type is coerced.
918    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
919        match self {
920            CoercibleScalarType::Coerced(t) => Some(t),
921            _ => None,
922        }
923    }
924
925    /// If the type is coerced, apply the mapping function to the contained
926    /// scalar type.
927    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
928    where
929        F: FnOnce(SqlScalarType) -> SqlScalarType,
930    {
931        match self {
932            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
933            _ => self,
934        }
935    }
936
937    /// If the type is an coercible record, forcibly converts to a coerced
938    /// record type. Any uncoerced field types are assumed to be of type text.
939    ///
940    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
941    /// accepts a type hint that can indicate the types of uncoerced field
942    /// types.
943    pub fn force_coerced_if_record(&mut self) {
944        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
945            let mut fields = vec![];
946            for (i, uf) in uncoerced_fields.enumerate() {
947                let name = ColumnName::from(format!("f{}", i + 1));
948                let ty = match uf {
949                    CoercibleColumnType::Coerced(ty) => ty,
950                    CoercibleColumnType::Record(mut fields) => {
951                        convert(fields.drain(..)).nullable(false)
952                    }
953                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
954                };
955                fields.push((name, ty))
956            }
957            SqlScalarType::Record {
958                fields: fields.into(),
959                custom_id: None,
960            }
961        }
962
963        if let CoercibleScalarType::Record(fields) = self {
964            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
965        }
966    }
967}
968
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` (presumably `HirScalarExpr` — confirm) and
/// `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object that describes this expression's type.
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `inner` is the type of the relation the expression is evaluated against.
    /// `outers` holds the types of enclosing scopes; it is presumably ordered
    /// innermost first, as `HirRelationExpr::typ` does for join inputs
    /// (confirm). `params` maps parameter numbers to their types.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
983
984impl AbstractExpr for CoercibleScalarExpr {
985    type Type = CoercibleColumnType;
986
987    fn typ(
988        &self,
989        outers: &[SqlRelationType],
990        inner: &SqlRelationType,
991        params: &BTreeMap<usize, SqlScalarType>,
992    ) -> Self::Type {
993        match self {
994            CoercibleScalarExpr::Coerced(expr) => {
995                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
996            }
997            CoercibleScalarExpr::LiteralRecord(scalars) => {
998                let fields = scalars
999                    .iter()
1000                    .map(|s| s.typ(outers, inner, params))
1001                    .collect();
1002                CoercibleColumnType::Record(fields)
1003            }
1004            _ => CoercibleColumnType::Uncoerced,
1005        }
1006    }
1007}
1008
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object produced by [`AbstractColumnType::scalar_type`]
    /// (`SqlScalarType` or `CoercibleScalarType` for the two implementors).
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1020
1021impl AbstractColumnType for SqlColumnType {
1022    type AbstractScalarType = SqlScalarType;
1023
1024    fn scalar_type(self) -> Self::AbstractScalarType {
1025        self.scalar_type
1026    }
1027}
1028
1029impl AbstractColumnType for CoercibleColumnType {
1030    type AbstractScalarType = CoercibleScalarType;
1031
1032    fn scalar_type(self) -> Self::AbstractScalarType {
1033        match self {
1034            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1035            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1036            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1037        }
1038    }
1039}
1040
1041impl From<HirScalarExpr> for CoercibleScalarExpr {
1042    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1043        CoercibleScalarExpr::Coerced(expr)
1044    }
1045}
1046
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` lets expressions refer to columns while stating explicitly
/// which level each column belongs to. This avoids manual bookkeeping of the
/// columns' actual locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery levels *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Scope level: 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local identifier of the referenced column.
    pub column: usize,
}
1068
/// The kind of join performed by a `HirRelationExpr::Join`.
///
/// The kind determines which side's columns become nullable in the join's
/// output type (see `HirRelationExpr::typ`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Neither side's columns are made nullable.
    Inner,
    /// The right side's columns are made nullable.
    LeftOuter,
    /// The left side's columns are made nullable.
    RightOuter,
    /// Both sides' columns are made nullable.
    FullOuter,
}
1076
1077impl fmt::Display for JoinKind {
1078    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1079        write!(
1080            f,
1081            "{}",
1082            match self {
1083                JoinKind::Inner => "Inner",
1084                JoinKind::LeftOuter => "LeftOuter",
1085                JoinKind::RightOuter => "RightOuter",
1086                JoinKind::FullOuter => "FullOuter",
1087            }
1088        )
1089    }
1090}
1091
1092impl JoinKind {
1093    pub fn can_be_correlated(&self) -> bool {
1094        match self {
1095            JoinKind::Inner | JoinKind::LeftOuter => true,
1096            JoinKind::RightOuter | JoinKind::FullOuter => false,
1097        }
1098    }
1099}
1100
/// An aggregate function applied to a scalar expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are aggregated.
    pub expr: Box<HirScalarExpr>,
    /// Whether only distinct input values are aggregated (presumably SQL
    /// `agg(DISTINCT ...)` — confirm against the planner).
    pub distinct: bool,
}
1107
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, aggregate columns are more often nullable here than in
/// `expr`. These aggregates may be applied over empty result sets and should
/// be null in those cases, whereas `expr` variants only return null values
/// when supplied nulls as input.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    // Per-type maximum, minimum, and sum aggregates. Like every aggregate
    // other than `Count`, they produce null on empty input.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// Counts its inputs. This is the only aggregate whose output column is
    /// not nullable (see `output_type`).
    Count,
    /// Boolean (OR-style) aggregate with `bool` output and identity `false`.
    Any,
    /// Boolean (AND-style) aggregate with `bool` output and identity `true`.
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Aggregates its inputs into a single `String` result, with the inputs
    /// ordered by `order_by`.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations. Its input is a record, and each
    /// of the record's components is the input to one of the `AggregateFunc`s
    /// in `funcs`.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1211
1212impl AggregateFunc {
1213    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1214    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1215        match self {
1216            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1217            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1218            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1219            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1220            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1221            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1222            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1223            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1224            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1225            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1226            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1227            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1228            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1229            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1230            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1231            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1232            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1233            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1234            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1235            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1236            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1237            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1238            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1239            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1240            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1241            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1242            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1243            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1244            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1245            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1246            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1247            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1248            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1249            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1250            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1251            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1252            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1253            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1254            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1255            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1256            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1257            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1258            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1259            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1260            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1261            AggregateFunc::All => mz_expr::AggregateFunc::All,
1262            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1263            AggregateFunc::JsonbObjectAgg { order_by } => {
1264                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1265            }
1266            AggregateFunc::MapAgg {
1267                order_by,
1268                value_type,
1269            } => mz_expr::AggregateFunc::MapAgg {
1270                order_by,
1271                value_type,
1272            },
1273            AggregateFunc::ArrayConcat { order_by } => {
1274                mz_expr::AggregateFunc::ArrayConcat { order_by }
1275            }
1276            AggregateFunc::ListConcat { order_by } => {
1277                mz_expr::AggregateFunc::ListConcat { order_by }
1278            }
1279            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1280            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1281            // `AggregateWindowExpr::into_expr`.
1282            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1283                panic!("into_expr called on FusedWindowAgg")
1284            }
1285            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1286        }
1287    }
1288
1289    /// Returns a datum whose inclusion in the aggregation will not change its
1290    /// result.
1291    ///
1292    /// # Panics
1293    ///
1294    /// Panics if called on a `FusedWindowAgg`.
1295    pub fn identity_datum(&self) -> Datum<'static> {
1296        match self {
1297            AggregateFunc::Any => Datum::False,
1298            AggregateFunc::All => Datum::True,
1299            AggregateFunc::Dummy => Datum::Dummy,
1300            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1301            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1302            AggregateFunc::MaxNumeric
1303            | AggregateFunc::MaxInt16
1304            | AggregateFunc::MaxInt32
1305            | AggregateFunc::MaxInt64
1306            | AggregateFunc::MaxUInt16
1307            | AggregateFunc::MaxUInt32
1308            | AggregateFunc::MaxUInt64
1309            | AggregateFunc::MaxMzTimestamp
1310            | AggregateFunc::MaxFloat32
1311            | AggregateFunc::MaxFloat64
1312            | AggregateFunc::MaxBool
1313            | AggregateFunc::MaxString
1314            | AggregateFunc::MaxDate
1315            | AggregateFunc::MaxTimestamp
1316            | AggregateFunc::MaxTimestampTz
1317            | AggregateFunc::MaxInterval
1318            | AggregateFunc::MaxTime
1319            | AggregateFunc::MinNumeric
1320            | AggregateFunc::MinInt16
1321            | AggregateFunc::MinInt32
1322            | AggregateFunc::MinInt64
1323            | AggregateFunc::MinUInt16
1324            | AggregateFunc::MinUInt32
1325            | AggregateFunc::MinUInt64
1326            | AggregateFunc::MinMzTimestamp
1327            | AggregateFunc::MinFloat32
1328            | AggregateFunc::MinFloat64
1329            | AggregateFunc::MinBool
1330            | AggregateFunc::MinString
1331            | AggregateFunc::MinDate
1332            | AggregateFunc::MinTimestamp
1333            | AggregateFunc::MinTimestampTz
1334            | AggregateFunc::MinInterval
1335            | AggregateFunc::MinTime
1336            | AggregateFunc::SumInt16
1337            | AggregateFunc::SumInt32
1338            | AggregateFunc::SumInt64
1339            | AggregateFunc::SumUInt16
1340            | AggregateFunc::SumUInt32
1341            | AggregateFunc::SumUInt64
1342            | AggregateFunc::SumFloat32
1343            | AggregateFunc::SumFloat64
1344            | AggregateFunc::SumNumeric
1345            | AggregateFunc::Count
1346            | AggregateFunc::JsonbAgg { .. }
1347            | AggregateFunc::JsonbObjectAgg { .. }
1348            | AggregateFunc::MapAgg { .. }
1349            | AggregateFunc::StringAgg { .. } => Datum::Null,
1350            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1351                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1352                // in HIR planning, because it is introduced only during HIR transformation.
1353                //
1354                // The implementation could be something like the following, except that we need to
1355                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1356                // ```
1357                // let temp_storage = RowArena::new();
1358                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1359                // ```
1360                panic!("FusedWindowAgg doesn't have an identity_datum")
1361            }
1362        }
1363    }
1364
1365    /// The output column type for the result of an aggregation.
1366    ///
1367    /// The output column type also contains nullability information, which
1368    /// is (without further information) true for aggregations that are not
1369    /// counts.
1370    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1371        let scalar_type = match self {
1372            AggregateFunc::Count => SqlScalarType::Int64,
1373            AggregateFunc::Any => SqlScalarType::Bool,
1374            AggregateFunc::All => SqlScalarType::Bool,
1375            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1376            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1377            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1378            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1379            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1380                max_scale: Some(NumericMaxScale::ZERO),
1381            },
1382            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1383            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1384                max_scale: Some(NumericMaxScale::ZERO),
1385            },
1386            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1387                value_type: Box::new(value_type.clone()),
1388                custom_id: None,
1389            },
1390            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1391                match input_type.scalar_type {
1392                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1393                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1394                    _ => unreachable!(),
1395                }
1396            }
1397            AggregateFunc::MaxNumeric
1398            | AggregateFunc::MaxInt16
1399            | AggregateFunc::MaxInt32
1400            | AggregateFunc::MaxInt64
1401            | AggregateFunc::MaxUInt16
1402            | AggregateFunc::MaxUInt32
1403            | AggregateFunc::MaxUInt64
1404            | AggregateFunc::MaxMzTimestamp
1405            | AggregateFunc::MaxFloat32
1406            | AggregateFunc::MaxFloat64
1407            | AggregateFunc::MaxBool
1408            | AggregateFunc::MaxString
1409            | AggregateFunc::MaxDate
1410            | AggregateFunc::MaxTimestamp
1411            | AggregateFunc::MaxTimestampTz
1412            | AggregateFunc::MaxInterval
1413            | AggregateFunc::MaxTime
1414            | AggregateFunc::MinNumeric
1415            | AggregateFunc::MinInt16
1416            | AggregateFunc::MinInt32
1417            | AggregateFunc::MinInt64
1418            | AggregateFunc::MinUInt16
1419            | AggregateFunc::MinUInt32
1420            | AggregateFunc::MinUInt64
1421            | AggregateFunc::MinMzTimestamp
1422            | AggregateFunc::MinFloat32
1423            | AggregateFunc::MinFloat64
1424            | AggregateFunc::MinBool
1425            | AggregateFunc::MinString
1426            | AggregateFunc::MinDate
1427            | AggregateFunc::MinTimestamp
1428            | AggregateFunc::MinTimestampTz
1429            | AggregateFunc::MinInterval
1430            | AggregateFunc::MinTime
1431            | AggregateFunc::SumFloat32
1432            | AggregateFunc::SumFloat64
1433            | AggregateFunc::SumNumeric
1434            | AggregateFunc::Dummy => input_type.scalar_type,
1435            AggregateFunc::FusedWindowAgg { funcs } => {
1436                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1437                SqlScalarType::Record {
1438                    fields: funcs
1439                        .iter()
1440                        .zip_eq(input_types)
1441                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1442                        .collect(),
1443                    custom_id: None,
1444                }
1445            }
1446        };
1447        // max/min/sum return null on empty sets
1448        let nullable = !matches!(self, AggregateFunc::Count);
1449        scalar_type.nullable(nullable)
1450    }
1451
1452    pub fn is_order_sensitive(&self) -> bool {
1453        use AggregateFunc::*;
1454        matches!(
1455            self,
1456            JsonbAgg { .. }
1457                | JsonbObjectAgg { .. }
1458                | MapAgg { .. }
1459                | ArrayConcat { .. }
1460                | ListConcat { .. }
1461                | StringAgg { .. }
1462        )
1463    }
1464}
1465
1466impl HirRelationExpr {
1467    pub fn typ(
1468        &self,
1469        outers: &[SqlRelationType],
1470        params: &BTreeMap<usize, SqlScalarType>,
1471    ) -> SqlRelationType {
1472        stack::maybe_grow(|| match self {
1473            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1474            HirRelationExpr::Get { typ, .. } => typ.clone(),
1475            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1476            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1477            HirRelationExpr::Project { input, outputs } => {
1478                let input_typ = input.typ(outers, params);
1479                SqlRelationType::new(
1480                    outputs
1481                        .iter()
1482                        .map(|&i| input_typ.column_types[i].clone())
1483                        .collect(),
1484                )
1485            }
1486            HirRelationExpr::Map { input, scalars } => {
1487                let mut typ = input.typ(outers, params);
1488                for scalar in scalars {
1489                    typ.column_types.push(scalar.typ(outers, &typ, params));
1490                }
1491                typ
1492            }
1493            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1494            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1495                input.typ(outers, params)
1496            }
1497            HirRelationExpr::Join {
1498                left, right, kind, ..
1499            } => {
1500                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1501                let right_nullable =
1502                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1503                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1504                    let nullable = t.nullable || left_nullable;
1505                    t.nullable(nullable)
1506                });
1507                let mut outers = outers.to_vec();
1508                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1509                let rt = right
1510                    .typ(&outers, params)
1511                    .column_types
1512                    .into_iter()
1513                    .map(|t| {
1514                        let nullable = t.nullable || right_nullable;
1515                        t.nullable(nullable)
1516                    });
1517                SqlRelationType::new(lt.chain(rt).collect())
1518            }
1519            HirRelationExpr::Reduce {
1520                input,
1521                group_key,
1522                aggregates,
1523                expected_group_size: _,
1524            } => {
1525                let input_typ = input.typ(outers, params);
1526                let mut column_types = group_key
1527                    .iter()
1528                    .map(|&i| input_typ.column_types[i].clone())
1529                    .collect::<Vec<_>>();
1530                for agg in aggregates {
1531                    column_types.push(agg.typ(outers, &input_typ, params));
1532                }
1533                // TODO(frank): add primary key information.
1534                SqlRelationType::new(column_types)
1535            }
1536            // TODO(frank): check for removal; add primary key information.
1537            HirRelationExpr::Distinct { input }
1538            | HirRelationExpr::Negate { input }
1539            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1540            HirRelationExpr::Union { base, inputs } => {
1541                let mut base_cols = base.typ(outers, params).column_types;
1542                for input in inputs {
1543                    for (base_col, col) in base_cols
1544                        .iter_mut()
1545                        .zip_eq(input.typ(outers, params).column_types)
1546                    {
1547                        *base_col = base_col.union(&col).unwrap();
1548                    }
1549                }
1550                SqlRelationType::new(base_cols)
1551            }
1552        })
1553    }
1554
1555    pub fn arity(&self) -> usize {
1556        match self {
1557            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1558            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1559            HirRelationExpr::Let { body, .. } => body.arity(),
1560            HirRelationExpr::LetRec { body, .. } => body.arity(),
1561            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1562            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1563            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1564            HirRelationExpr::Filter { input, .. }
1565            | HirRelationExpr::TopK { input, .. }
1566            | HirRelationExpr::Distinct { input }
1567            | HirRelationExpr::Negate { input }
1568            | HirRelationExpr::Threshold { input } => input.arity(),
1569            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1570            HirRelationExpr::Union { base, .. } => base.arity(),
1571            HirRelationExpr::Reduce {
1572                group_key,
1573                aggregates,
1574                ..
1575            } => group_key.len() + aggregates.len(),
1576        }
1577    }
1578
1579    /// If self is a constant, return the value and the type, otherwise `None`.
1580    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1581        match self {
1582            Self::Constant { rows, typ } => Some((rows, typ)),
1583            _ => None,
1584        }
1585    }
1586
1587    /// Reports whether this expression contains a column reference to its
1588    /// direct parent scope.
1589    pub fn is_correlated(&self) -> bool {
1590        let mut correlated = false;
1591        #[allow(deprecated)]
1592        self.visit_columns(0, &mut |depth, col| {
1593            if col.level > depth && col.level - depth == 1 {
1594                correlated = true;
1595            }
1596        });
1597        correlated
1598    }
1599
1600    pub fn is_join_identity(&self) -> bool {
1601        match self {
1602            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1603            _ => false,
1604        }
1605    }
1606
1607    pub fn project(self, outputs: Vec<usize>) -> Self {
1608        if outputs.iter().copied().eq(0..self.arity()) {
1609            // The projection is trivial. Suppress it.
1610            self
1611        } else {
1612            HirRelationExpr::Project {
1613                input: Box::new(self),
1614                outputs,
1615            }
1616        }
1617    }
1618
1619    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1620        if scalars.is_empty() {
1621            // The map is trivial. Suppress it.
1622            self
1623        } else if let HirRelationExpr::Map {
1624            scalars: old_scalars,
1625            input: _,
1626        } = &mut self
1627        {
1628            // Map applied to a map. Fuse the maps.
1629            old_scalars.extend(scalars);
1630            self
1631        } else {
1632            HirRelationExpr::Map {
1633                input: Box::new(self),
1634                scalars,
1635            }
1636        }
1637    }
1638
1639    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1640        if let HirRelationExpr::Filter {
1641            input: _,
1642            predicates,
1643        } = &mut self
1644        {
1645            predicates.extend(preds);
1646            predicates.sort();
1647            predicates.dedup();
1648            self
1649        } else {
1650            preds.sort();
1651            preds.dedup();
1652            HirRelationExpr::Filter {
1653                input: Box::new(self),
1654                predicates: preds,
1655            }
1656        }
1657    }
1658
1659    pub fn reduce(
1660        self,
1661        group_key: Vec<usize>,
1662        aggregates: Vec<AggregateExpr>,
1663        expected_group_size: Option<u64>,
1664    ) -> Self {
1665        HirRelationExpr::Reduce {
1666            input: Box::new(self),
1667            group_key,
1668            aggregates,
1669            expected_group_size,
1670        }
1671    }
1672
1673    pub fn top_k(
1674        self,
1675        group_key: Vec<usize>,
1676        order_key: Vec<ColumnOrder>,
1677        limit: Option<HirScalarExpr>,
1678        offset: HirScalarExpr,
1679        expected_group_size: Option<u64>,
1680    ) -> Self {
1681        HirRelationExpr::TopK {
1682            input: Box::new(self),
1683            group_key,
1684            order_key,
1685            limit,
1686            offset,
1687            expected_group_size,
1688        }
1689    }
1690
1691    pub fn negate(self) -> Self {
1692        if let HirRelationExpr::Negate { input } = self {
1693            *input
1694        } else {
1695            HirRelationExpr::Negate {
1696                input: Box::new(self),
1697            }
1698        }
1699    }
1700
1701    pub fn distinct(self) -> Self {
1702        if let HirRelationExpr::Distinct { .. } = self {
1703            self
1704        } else {
1705            HirRelationExpr::Distinct {
1706                input: Box::new(self),
1707            }
1708        }
1709    }
1710
1711    pub fn threshold(self) -> Self {
1712        if let HirRelationExpr::Threshold { .. } = self {
1713            self
1714        } else {
1715            HirRelationExpr::Threshold {
1716                input: Box::new(self),
1717            }
1718        }
1719    }
1720
1721    pub fn union(self, other: Self) -> Self {
1722        let mut terms = Vec::new();
1723        if let HirRelationExpr::Union { base, inputs } = self {
1724            terms.push(*base);
1725            terms.extend(inputs);
1726        } else {
1727            terms.push(self);
1728        }
1729        if let HirRelationExpr::Union { base, inputs } = other {
1730            terms.push(*base);
1731            terms.extend(inputs);
1732        } else {
1733            terms.push(other);
1734        }
1735        HirRelationExpr::Union {
1736            base: Box::new(terms.remove(0)),
1737            inputs: terms,
1738        }
1739    }
1740
1741    pub fn exists(self) -> HirScalarExpr {
1742        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1743    }
1744
1745    pub fn select(self) -> HirScalarExpr {
1746        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1747    }
1748
1749    pub fn join(
1750        self,
1751        mut right: HirRelationExpr,
1752        on: HirScalarExpr,
1753        kind: JoinKind,
1754    ) -> HirRelationExpr {
1755        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1756        {
1757            // The join can be elided, but we need to adjust column references
1758            // on the right-hand side to account for the removal of the scope
1759            // introduced by the join.
1760            #[allow(deprecated)]
1761            right.visit_columns_mut(0, &mut |depth, col| {
1762                if col.level > depth {
1763                    col.level -= 1;
1764                }
1765            });
1766            right
1767        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1768            self
1769        } else {
1770            HirRelationExpr::Join {
1771                left: Box::new(self),
1772                right: Box::new(right),
1773                on,
1774                kind,
1775            }
1776        }
1777    }
1778
1779    pub fn take(&mut self) -> HirRelationExpr {
1780        mem::replace(
1781            self,
1782            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1783        )
1784    }
1785
1786    #[deprecated = "Use `Visit::visit_post`."]
1787    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1788    where
1789        F: FnMut(&'a Self, usize),
1790    {
1791        #[allow(deprecated)]
1792        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1793                                                 depth: usize|
1794         -> Result<(), ()> {
1795            f(e, depth);
1796            Ok(())
1797        });
1798    }
1799
1800    #[deprecated = "Use `Visit::try_visit_post`."]
1801    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1802    where
1803        F: FnMut(&'a Self, usize) -> Result<(), E>,
1804    {
1805        #[allow(deprecated)]
1806        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1807            e.visit_fallible(depth, f)
1808        })?;
1809        f(self, depth)
1810    }
1811
1812    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1813    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1814    where
1815        F: FnMut(&'a Self, usize) -> Result<(), E>,
1816    {
1817        match self {
1818            HirRelationExpr::Constant { .. }
1819            | HirRelationExpr::Get { .. }
1820            | HirRelationExpr::CallTable { .. } => (),
1821            HirRelationExpr::Let { body, value, .. } => {
1822                f(value, depth)?;
1823                f(body, depth)?;
1824            }
1825            HirRelationExpr::LetRec {
1826                limit: _,
1827                bindings,
1828                body,
1829            } => {
1830                for (_, _, value, _) in bindings.iter() {
1831                    f(value, depth)?;
1832                }
1833                f(body, depth)?;
1834            }
1835            HirRelationExpr::Project { input, .. } => {
1836                f(input, depth)?;
1837            }
1838            HirRelationExpr::Map { input, .. } => {
1839                f(input, depth)?;
1840            }
1841            HirRelationExpr::Filter { input, .. } => {
1842                f(input, depth)?;
1843            }
1844            HirRelationExpr::Join { left, right, .. } => {
1845                f(left, depth)?;
1846                f(right, depth + 1)?;
1847            }
1848            HirRelationExpr::Reduce { input, .. } => {
1849                f(input, depth)?;
1850            }
1851            HirRelationExpr::Distinct { input } => {
1852                f(input, depth)?;
1853            }
1854            HirRelationExpr::TopK { input, .. } => {
1855                f(input, depth)?;
1856            }
1857            HirRelationExpr::Negate { input } => {
1858                f(input, depth)?;
1859            }
1860            HirRelationExpr::Threshold { input } => {
1861                f(input, depth)?;
1862            }
1863            HirRelationExpr::Union { base, inputs } => {
1864                f(base, depth)?;
1865                for input in inputs {
1866                    f(input, depth)?;
1867                }
1868            }
1869        }
1870        Ok(())
1871    }
1872
1873    #[deprecated = "Use `Visit::visit_mut_post` instead."]
1874    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1875    where
1876        F: FnMut(&mut Self, usize),
1877    {
1878        #[allow(deprecated)]
1879        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1880                                                     depth: usize|
1881         -> Result<(), ()> {
1882            f(e, depth);
1883            Ok(())
1884        });
1885    }
1886
1887    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1888    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1889    where
1890        F: FnMut(&mut Self, usize) -> Result<(), E>,
1891    {
1892        #[allow(deprecated)]
1893        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1894            e.visit_mut_fallible(depth, f)
1895        })?;
1896        f(self, depth)
1897    }
1898
1899    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1900    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1901    where
1902        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1903    {
1904        match self {
1905            HirRelationExpr::Constant { .. }
1906            | HirRelationExpr::Get { .. }
1907            | HirRelationExpr::CallTable { .. } => (),
1908            HirRelationExpr::Let { body, value, .. } => {
1909                f(value, depth)?;
1910                f(body, depth)?;
1911            }
1912            HirRelationExpr::LetRec {
1913                limit: _,
1914                bindings,
1915                body,
1916            } => {
1917                for (_, _, value, _) in bindings.iter_mut() {
1918                    f(value, depth)?;
1919                }
1920                f(body, depth)?;
1921            }
1922            HirRelationExpr::Project { input, .. } => {
1923                f(input, depth)?;
1924            }
1925            HirRelationExpr::Map { input, .. } => {
1926                f(input, depth)?;
1927            }
1928            HirRelationExpr::Filter { input, .. } => {
1929                f(input, depth)?;
1930            }
1931            HirRelationExpr::Join { left, right, .. } => {
1932                f(left, depth)?;
1933                f(right, depth + 1)?;
1934            }
1935            HirRelationExpr::Reduce { input, .. } => {
1936                f(input, depth)?;
1937            }
1938            HirRelationExpr::Distinct { input } => {
1939                f(input, depth)?;
1940            }
1941            HirRelationExpr::TopK { input, .. } => {
1942                f(input, depth)?;
1943            }
1944            HirRelationExpr::Negate { input } => {
1945                f(input, depth)?;
1946            }
1947            HirRelationExpr::Threshold { input } => {
1948                f(input, depth)?;
1949            }
1950            HirRelationExpr::Union { base, inputs } => {
1951                f(base, depth)?;
1952                for input in inputs {
1953                    f(input, depth)?;
1954                }
1955            }
1956        }
1957        Ok(())
1958    }
1959
1960    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1961    /// Visits all scalar expressions within the sub-tree of the given relation.
1962    ///
1963    /// The `depth` argument should indicate the subquery nesting depth of the expression,
1964    /// which will be incremented when entering the RHS of a join or a subquery and
1965    /// presented to the supplied function `f`.
1966    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1967    where
1968        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1969    {
1970        #[allow(deprecated)]
1971        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1972                                         depth: usize|
1973         -> Result<(), E> {
1974            match e {
1975                HirRelationExpr::Join { on, .. } => {
1976                    f(on, depth)?;
1977                }
1978                HirRelationExpr::Map { scalars, .. } => {
1979                    for scalar in scalars {
1980                        f(scalar, depth)?;
1981                    }
1982                }
1983                HirRelationExpr::CallTable { exprs, .. } => {
1984                    for expr in exprs {
1985                        f(expr, depth)?;
1986                    }
1987                }
1988                HirRelationExpr::Filter { predicates, .. } => {
1989                    for predicate in predicates {
1990                        f(predicate, depth)?;
1991                    }
1992                }
1993                HirRelationExpr::Reduce { aggregates, .. } => {
1994                    for aggregate in aggregates {
1995                        f(&aggregate.expr, depth)?;
1996                    }
1997                }
1998                HirRelationExpr::TopK { limit, offset, .. } => {
1999                    if let Some(limit) = limit {
2000                        f(limit, depth)?;
2001                    }
2002                    f(offset, depth)?;
2003                }
2004                HirRelationExpr::Union { .. }
2005                | HirRelationExpr::Let { .. }
2006                | HirRelationExpr::LetRec { .. }
2007                | HirRelationExpr::Project { .. }
2008                | HirRelationExpr::Distinct { .. }
2009                | HirRelationExpr::Negate { .. }
2010                | HirRelationExpr::Threshold { .. }
2011                | HirRelationExpr::Constant { .. }
2012                | HirRelationExpr::Get { .. } => (),
2013            }
2014            Ok(())
2015        })
2016    }
2017
2018    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2019    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2020    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2021    where
2022        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2023    {
2024        #[allow(deprecated)]
2025        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2026                                             depth: usize|
2027         -> Result<(), E> {
2028            match e {
2029                HirRelationExpr::Join { on, .. } => {
2030                    f(on, depth)?;
2031                }
2032                HirRelationExpr::Map { scalars, .. } => {
2033                    for scalar in scalars.iter_mut() {
2034                        f(scalar, depth)?;
2035                    }
2036                }
2037                HirRelationExpr::CallTable { exprs, .. } => {
2038                    for expr in exprs.iter_mut() {
2039                        f(expr, depth)?;
2040                    }
2041                }
2042                HirRelationExpr::Filter { predicates, .. } => {
2043                    for predicate in predicates.iter_mut() {
2044                        f(predicate, depth)?;
2045                    }
2046                }
2047                HirRelationExpr::Reduce { aggregates, .. } => {
2048                    for aggregate in aggregates.iter_mut() {
2049                        f(&mut aggregate.expr, depth)?;
2050                    }
2051                }
2052                HirRelationExpr::TopK { limit, offset, .. } => {
2053                    if let Some(limit) = limit {
2054                        f(limit, depth)?;
2055                    }
2056                    f(offset, depth)?;
2057                }
2058                HirRelationExpr::Union { .. }
2059                | HirRelationExpr::Let { .. }
2060                | HirRelationExpr::LetRec { .. }
2061                | HirRelationExpr::Project { .. }
2062                | HirRelationExpr::Distinct { .. }
2063                | HirRelationExpr::Negate { .. }
2064                | HirRelationExpr::Threshold { .. }
2065                | HirRelationExpr::Constant { .. }
2066                | HirRelationExpr::Get { .. } => (),
2067            }
2068            Ok(())
2069        })
2070    }
2071
2072    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2073    /// Visits the column references in this relation expression.
2074    ///
2075    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2076    /// which will be incremented when entering the RHS of a join or a subquery and
2077    /// presented to the supplied function `f`.
2078    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2079    where
2080        F: FnMut(usize, &ColumnRef),
2081    {
2082        #[allow(deprecated)]
2083        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2084                                                           depth: usize|
2085         -> Result<(), ()> {
2086            e.visit_columns(depth, f);
2087            Ok(())
2088        });
2089    }
2090
2091    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2092    /// Like `visit_columns`, but permits mutating the column references.
2093    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2094    where
2095        F: FnMut(usize, &mut ColumnRef),
2096    {
2097        #[allow(deprecated)]
2098        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2099                                                               depth: usize|
2100         -> Result<(), ()> {
2101            e.visit_columns_mut(depth, f);
2102            Ok(())
2103        });
2104    }
2105
2106    /// Replaces any parameter references in the expression with the
2107    /// corresponding datum from `params`.
2108    pub fn bind_parameters(
2109        &mut self,
2110        scx: &StatementContext,
2111        lifetime: QueryLifetime,
2112        params: &Params,
2113    ) -> Result<(), PlanError> {
2114        #[allow(deprecated)]
2115        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2116            e.bind_parameters(scx, lifetime, params)
2117        })
2118    }
2119
2120    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2121        let mut contains_parameters = false;
2122        #[allow(deprecated)]
2123        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2124            if e.contains_parameters() {
2125                contains_parameters = true;
2126            }
2127            Ok::<(), PlanError>(())
2128        })?;
2129        Ok(contains_parameters)
2130    }
2131
2132    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2133    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2134        #[allow(deprecated)]
2135        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2136                                                               depth: usize|
2137         -> Result<(), ()> {
2138            e.splice_parameters(params, depth);
2139            Ok(())
2140        });
2141    }
2142
2143    /// Constructs a constant collection from specific rows and schema.
2144    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2145        let rows = rows
2146            .into_iter()
2147            .map(move |datums| Row::pack_slice(&datums))
2148            .collect();
2149        HirRelationExpr::Constant { rows, typ }
2150    }
2151
2152    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2153    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2154    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2155    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2156    /// finishing into a trivial finishing.
2157    pub fn finish_maintained(
2158        &mut self,
2159        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2160        group_size_hints: GroupSizeHints,
2161    ) {
2162        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2163            let old_finishing = mem::replace(
2164                finishing,
2165                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2166            );
2167            *self = HirRelationExpr::top_k(
2168                std::mem::replace(
2169                    self,
2170                    HirRelationExpr::Constant {
2171                        rows: vec![],
2172                        typ: SqlRelationType::new(Vec::new()),
2173                    },
2174                ),
2175                vec![],
2176                old_finishing.order_by,
2177                old_finishing.limit,
2178                old_finishing.offset,
2179                group_size_hints.limit_input_group_size,
2180            )
2181            .project(old_finishing.project);
2182        }
2183    }
2184
2185    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2186    ///
2187    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2188    /// parameter is not an HirScalarExpr anymore.)
2189    pub fn trivial_row_set_finishing_hir(
2190        arity: usize,
2191    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2192        RowSetFinishing {
2193            order_by: Vec::new(),
2194            limit: None,
2195            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2196            project: (0..arity).collect(),
2197        }
2198    }
2199
2200    /// True if the finishing does nothing to any result set.
2201    ///
2202    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2203    /// parameter is not an HirScalarExpr anymore.)
2204    pub fn is_trivial_row_set_finishing_hir(
2205        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2206        arity: usize,
2207    ) -> bool {
2208        rsf.limit.is_none()
2209            && rsf.order_by.is_empty()
2210            && rsf
2211                .offset
2212                .clone()
2213                .try_into_literal_int64()
2214                .is_ok_and(|o| o == 0)
2215            && rsf.project.iter().copied().eq(0..arity)
2216    }
2217
2218    /// The HirRelationExpr is considered potentially expensive if and only if
2219    /// at least one of the following conditions is true:
2220    ///
2221    ///  - It contains at least one CallTable or a Reduce operator.
2222    ///  - It contains at least one HirScalarExpr with a function call.
2223    ///
2224    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2225    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2226    pub fn could_run_expensive_function(&self) -> bool {
2227        let mut result = false;
2228        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2229            use HirRelationExpr::*;
2230            use HirScalarExpr::*;
2231
2232            self.visit_children(|scalar: &HirScalarExpr| {
2233                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2234                    result |= match scalar {
2235                        Column(..)
2236                        | Literal(..)
2237                        | CallUnmaterializable(..)
2238                        | If { .. }
2239                        | Parameter(..)
2240                        | Select(..)
2241                        | Exists(..) => false,
2242                        // Function calls are considered expensive
2243                        CallUnary { .. }
2244                        | CallBinary { .. }
2245                        | CallVariadic { .. }
2246                        | Windowing(..) => true,
2247                    };
2248                }) {
2249                    // Conservatively set `true` on RecursionLimitError.
2250                    result = true;
2251                }
2252            });
2253
2254            // CallTable has a table function; Reduce has an aggregate function.
2255            // Other constructs use MirScalarExpr to run a function
2256            result |= matches!(e, CallTable { .. } | Reduce { .. });
2257        }) {
2258            // Conservatively set `true` on RecursionLimitError.
2259            result = true;
2260        }
2261
2262        result
2263    }
2264
2265    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2266    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2267        let mut contains = false;
2268        self.visit_post(&mut |expr| {
2269            expr.visit_children(|expr: &HirScalarExpr| {
2270                contains = contains || expr.contains_temporal()
2271            })
2272        })?;
2273        Ok(contains)
2274    }
2275
2276    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2277    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2278        let mut contains = false;
2279        self.visit_post(&mut |expr| {
2280            expr.visit_children(|expr: &HirScalarExpr| {
2281                contains = contains || expr.contains_unmaterializable()
2282            })
2283        })?;
2284        Ok(contains)
2285    }
2286}
2287
2288impl CollectionPlan for HirRelationExpr {
2289    // !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2290    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2291    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2292        if let Self::Get {
2293            id: Id::Global(id), ..
2294        } = self
2295        {
2296            out.insert(*id);
2297        }
2298        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2299    }
2300}
2301
2302impl VisitChildren<Self> for HirRelationExpr {
2303    fn visit_children<F>(&self, mut f: F)
2304    where
2305        F: FnMut(&Self),
2306    {
2307        // subqueries of type HirRelationExpr might be wrapped in
2308        // Exists or Select variants within HirScalarExpr trees
2309        // attached at the current node, and we want to visit them as well
2310        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2311            #[allow(deprecated)]
2312            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2313                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2314                    f(expr.as_ref())
2315                }
2316                _ => (),
2317            });
2318        });
2319
2320        use HirRelationExpr::*;
2321        match self {
2322            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2323            Let {
2324                name: _,
2325                id: _,
2326                value,
2327                body,
2328            } => {
2329                f(value);
2330                f(body);
2331            }
2332            LetRec {
2333                limit: _,
2334                bindings,
2335                body,
2336            } => {
2337                for (_, _, value, _) in bindings.iter() {
2338                    f(value);
2339                }
2340                f(body);
2341            }
2342            Project { input, outputs: _ } => f(input),
2343            Map { input, scalars: _ } => {
2344                f(input);
2345            }
2346            CallTable { func: _, exprs: _ } => (),
2347            Filter {
2348                input,
2349                predicates: _,
2350            } => {
2351                f(input);
2352            }
2353            Join {
2354                left,
2355                right,
2356                on: _,
2357                kind: _,
2358            } => {
2359                f(left);
2360                f(right);
2361            }
2362            Reduce {
2363                input,
2364                group_key: _,
2365                aggregates: _,
2366                expected_group_size: _,
2367            } => {
2368                f(input);
2369            }
2370            Distinct { input }
2371            | TopK {
2372                input,
2373                group_key: _,
2374                order_key: _,
2375                limit: _,
2376                offset: _,
2377                expected_group_size: _,
2378            }
2379            | Negate { input }
2380            | Threshold { input } => {
2381                f(input);
2382            }
2383            Union { base, inputs } => {
2384                f(base);
2385                for input in inputs {
2386                    f(input);
2387                }
2388            }
2389        }
2390    }
2391
2392    fn visit_mut_children<F>(&mut self, mut f: F)
2393    where
2394        F: FnMut(&mut Self),
2395    {
2396        // subqueries of type HirRelationExpr might be wrapped in
2397        // Exists or Select variants within HirScalarExpr trees
2398        // attached at the current node, and we want to visit them as well
2399        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2400            #[allow(deprecated)]
2401            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2402                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2403                    f(expr.as_mut())
2404                }
2405                _ => (),
2406            });
2407        });
2408
2409        use HirRelationExpr::*;
2410        match self {
2411            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2412            Let {
2413                name: _,
2414                id: _,
2415                value,
2416                body,
2417            } => {
2418                f(value);
2419                f(body);
2420            }
2421            LetRec {
2422                limit: _,
2423                bindings,
2424                body,
2425            } => {
2426                for (_, _, value, _) in bindings.iter_mut() {
2427                    f(value);
2428                }
2429                f(body);
2430            }
2431            Project { input, outputs: _ } => f(input),
2432            Map { input, scalars: _ } => {
2433                f(input);
2434            }
2435            CallTable { func: _, exprs: _ } => (),
2436            Filter {
2437                input,
2438                predicates: _,
2439            } => {
2440                f(input);
2441            }
2442            Join {
2443                left,
2444                right,
2445                on: _,
2446                kind: _,
2447            } => {
2448                f(left);
2449                f(right);
2450            }
2451            Reduce {
2452                input,
2453                group_key: _,
2454                aggregates: _,
2455                expected_group_size: _,
2456            } => {
2457                f(input);
2458            }
2459            Distinct { input }
2460            | TopK {
2461                input,
2462                group_key: _,
2463                order_key: _,
2464                limit: _,
2465                offset: _,
2466                expected_group_size: _,
2467            }
2468            | Negate { input }
2469            | Threshold { input } => {
2470                f(input);
2471            }
2472            Union { base, inputs } => {
2473                f(base);
2474                for input in inputs {
2475                    f(input);
2476                }
2477            }
2478        }
2479    }
2480
2481    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2482    where
2483        F: FnMut(&Self) -> Result<(), E>,
2484        E: From<RecursionLimitError>,
2485    {
2486        // subqueries of type HirRelationExpr might be wrapped in
2487        // Exists or Select variants within HirScalarExpr trees
2488        // attached at the current node, and we want to visit them as well
2489        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2490            Visit::try_visit_post(expr, &mut |expr| match expr {
2491                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2492                    f(expr.as_ref())
2493                }
2494                _ => Ok(()),
2495            })
2496        })?;
2497
2498        use HirRelationExpr::*;
2499        match self {
2500            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2501            Let {
2502                name: _,
2503                id: _,
2504                value,
2505                body,
2506            } => {
2507                f(value)?;
2508                f(body)?;
2509            }
2510            LetRec {
2511                limit: _,
2512                bindings,
2513                body,
2514            } => {
2515                for (_, _, value, _) in bindings.iter() {
2516                    f(value)?;
2517                }
2518                f(body)?;
2519            }
2520            Project { input, outputs: _ } => f(input)?,
2521            Map { input, scalars: _ } => {
2522                f(input)?;
2523            }
2524            CallTable { func: _, exprs: _ } => (),
2525            Filter {
2526                input,
2527                predicates: _,
2528            } => {
2529                f(input)?;
2530            }
2531            Join {
2532                left,
2533                right,
2534                on: _,
2535                kind: _,
2536            } => {
2537                f(left)?;
2538                f(right)?;
2539            }
2540            Reduce {
2541                input,
2542                group_key: _,
2543                aggregates: _,
2544                expected_group_size: _,
2545            } => {
2546                f(input)?;
2547            }
2548            Distinct { input }
2549            | TopK {
2550                input,
2551                group_key: _,
2552                order_key: _,
2553                limit: _,
2554                offset: _,
2555                expected_group_size: _,
2556            }
2557            | Negate { input }
2558            | Threshold { input } => {
2559                f(input)?;
2560            }
2561            Union { base, inputs } => {
2562                f(base)?;
2563                for input in inputs {
2564                    f(input)?;
2565                }
2566            }
2567        }
2568        Ok(())
2569    }
2570
2571    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2572    where
2573        F: FnMut(&mut Self) -> Result<(), E>,
2574        E: From<RecursionLimitError>,
2575    {
2576        // subqueries of type HirRelationExpr might be wrapped in
2577        // Exists or Select variants within HirScalarExpr trees
2578        // attached at the current node, and we want to visit them as well
2579        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2580            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2581                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2582                    f(expr.as_mut())
2583                }
2584                _ => Ok(()),
2585            })
2586        })?;
2587
2588        use HirRelationExpr::*;
2589        match self {
2590            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2591            Let {
2592                name: _,
2593                id: _,
2594                value,
2595                body,
2596            } => {
2597                f(value)?;
2598                f(body)?;
2599            }
2600            LetRec {
2601                limit: _,
2602                bindings,
2603                body,
2604            } => {
2605                for (_, _, value, _) in bindings.iter_mut() {
2606                    f(value)?;
2607                }
2608                f(body)?;
2609            }
2610            Project { input, outputs: _ } => f(input)?,
2611            Map { input, scalars: _ } => {
2612                f(input)?;
2613            }
2614            CallTable { func: _, exprs: _ } => (),
2615            Filter {
2616                input,
2617                predicates: _,
2618            } => {
2619                f(input)?;
2620            }
2621            Join {
2622                left,
2623                right,
2624                on: _,
2625                kind: _,
2626            } => {
2627                f(left)?;
2628                f(right)?;
2629            }
2630            Reduce {
2631                input,
2632                group_key: _,
2633                aggregates: _,
2634                expected_group_size: _,
2635            } => {
2636                f(input)?;
2637            }
2638            Distinct { input }
2639            | TopK {
2640                input,
2641                group_key: _,
2642                order_key: _,
2643                limit: _,
2644                offset: _,
2645                expected_group_size: _,
2646            }
2647            | Negate { input }
2648            | Threshold { input } => {
2649                f(input)?;
2650            }
2651            Union { base, inputs } => {
2652                f(base)?;
2653                for input in inputs {
2654                    f(input)?;
2655                }
2656            }
2657        }
2658        Ok(())
2659    }
2660}
2661
2662impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2663    fn visit_children<F>(&self, mut f: F)
2664    where
2665        F: FnMut(&HirScalarExpr),
2666    {
2667        use HirRelationExpr::*;
2668        match self {
2669            Constant { rows: _, typ: _ }
2670            | Get { id: _, typ: _ }
2671            | Let {
2672                name: _,
2673                id: _,
2674                value: _,
2675                body: _,
2676            }
2677            | LetRec {
2678                limit: _,
2679                bindings: _,
2680                body: _,
2681            }
2682            | Project {
2683                input: _,
2684                outputs: _,
2685            } => (),
2686            Map { input: _, scalars } => {
2687                for scalar in scalars {
2688                    f(scalar);
2689                }
2690            }
2691            CallTable { func: _, exprs } => {
2692                for expr in exprs {
2693                    f(expr);
2694                }
2695            }
2696            Filter {
2697                input: _,
2698                predicates,
2699            } => {
2700                for predicate in predicates {
2701                    f(predicate);
2702                }
2703            }
2704            Join {
2705                left: _,
2706                right: _,
2707                on,
2708                kind: _,
2709            } => f(on),
2710            Reduce {
2711                input: _,
2712                group_key: _,
2713                aggregates,
2714                expected_group_size: _,
2715            } => {
2716                for aggregate in aggregates {
2717                    f(aggregate.expr.as_ref());
2718                }
2719            }
2720            TopK {
2721                input: _,
2722                group_key: _,
2723                order_key: _,
2724                limit,
2725                offset,
2726                expected_group_size: _,
2727            } => {
2728                if let Some(limit) = limit {
2729                    f(limit)
2730                }
2731                f(offset)
2732            }
2733            Distinct { input: _ }
2734            | Negate { input: _ }
2735            | Threshold { input: _ }
2736            | Union { base: _, inputs: _ } => (),
2737        }
2738    }
2739
2740    fn visit_mut_children<F>(&mut self, mut f: F)
2741    where
2742        F: FnMut(&mut HirScalarExpr),
2743    {
2744        use HirRelationExpr::*;
2745        match self {
2746            Constant { rows: _, typ: _ }
2747            | Get { id: _, typ: _ }
2748            | Let {
2749                name: _,
2750                id: _,
2751                value: _,
2752                body: _,
2753            }
2754            | LetRec {
2755                limit: _,
2756                bindings: _,
2757                body: _,
2758            }
2759            | Project {
2760                input: _,
2761                outputs: _,
2762            } => (),
2763            Map { input: _, scalars } => {
2764                for scalar in scalars {
2765                    f(scalar);
2766                }
2767            }
2768            CallTable { func: _, exprs } => {
2769                for expr in exprs {
2770                    f(expr);
2771                }
2772            }
2773            Filter {
2774                input: _,
2775                predicates,
2776            } => {
2777                for predicate in predicates {
2778                    f(predicate);
2779                }
2780            }
2781            Join {
2782                left: _,
2783                right: _,
2784                on,
2785                kind: _,
2786            } => f(on),
2787            Reduce {
2788                input: _,
2789                group_key: _,
2790                aggregates,
2791                expected_group_size: _,
2792            } => {
2793                for aggregate in aggregates {
2794                    f(aggregate.expr.as_mut());
2795                }
2796            }
2797            TopK {
2798                input: _,
2799                group_key: _,
2800                order_key: _,
2801                limit,
2802                offset,
2803                expected_group_size: _,
2804            } => {
2805                if let Some(limit) = limit {
2806                    f(limit)
2807                }
2808                f(offset)
2809            }
2810            Distinct { input: _ }
2811            | Negate { input: _ }
2812            | Threshold { input: _ }
2813            | Union { base: _, inputs: _ } => (),
2814        }
2815    }
2816
2817    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2818    where
2819        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2820        E: From<RecursionLimitError>,
2821    {
2822        use HirRelationExpr::*;
2823        match self {
2824            Constant { rows: _, typ: _ }
2825            | Get { id: _, typ: _ }
2826            | Let {
2827                name: _,
2828                id: _,
2829                value: _,
2830                body: _,
2831            }
2832            | LetRec {
2833                limit: _,
2834                bindings: _,
2835                body: _,
2836            }
2837            | Project {
2838                input: _,
2839                outputs: _,
2840            } => (),
2841            Map { input: _, scalars } => {
2842                for scalar in scalars {
2843                    f(scalar)?;
2844                }
2845            }
2846            CallTable { func: _, exprs } => {
2847                for expr in exprs {
2848                    f(expr)?;
2849                }
2850            }
2851            Filter {
2852                input: _,
2853                predicates,
2854            } => {
2855                for predicate in predicates {
2856                    f(predicate)?;
2857                }
2858            }
2859            Join {
2860                left: _,
2861                right: _,
2862                on,
2863                kind: _,
2864            } => f(on)?,
2865            Reduce {
2866                input: _,
2867                group_key: _,
2868                aggregates,
2869                expected_group_size: _,
2870            } => {
2871                for aggregate in aggregates {
2872                    f(aggregate.expr.as_ref())?;
2873                }
2874            }
2875            TopK {
2876                input: _,
2877                group_key: _,
2878                order_key: _,
2879                limit,
2880                offset,
2881                expected_group_size: _,
2882            } => {
2883                if let Some(limit) = limit {
2884                    f(limit)?
2885                }
2886                f(offset)?
2887            }
2888            Distinct { input: _ }
2889            | Negate { input: _ }
2890            | Threshold { input: _ }
2891            | Union { base: _, inputs: _ } => (),
2892        }
2893        Ok(())
2894    }
2895
2896    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2897    where
2898        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2899        E: From<RecursionLimitError>,
2900    {
2901        use HirRelationExpr::*;
2902        match self {
2903            Constant { rows: _, typ: _ }
2904            | Get { id: _, typ: _ }
2905            | Let {
2906                name: _,
2907                id: _,
2908                value: _,
2909                body: _,
2910            }
2911            | LetRec {
2912                limit: _,
2913                bindings: _,
2914                body: _,
2915            }
2916            | Project {
2917                input: _,
2918                outputs: _,
2919            } => (),
2920            Map { input: _, scalars } => {
2921                for scalar in scalars {
2922                    f(scalar)?;
2923                }
2924            }
2925            CallTable { func: _, exprs } => {
2926                for expr in exprs {
2927                    f(expr)?;
2928                }
2929            }
2930            Filter {
2931                input: _,
2932                predicates,
2933            } => {
2934                for predicate in predicates {
2935                    f(predicate)?;
2936                }
2937            }
2938            Join {
2939                left: _,
2940                right: _,
2941                on,
2942                kind: _,
2943            } => f(on)?,
2944            Reduce {
2945                input: _,
2946                group_key: _,
2947                aggregates,
2948                expected_group_size: _,
2949            } => {
2950                for aggregate in aggregates {
2951                    f(aggregate.expr.as_mut())?;
2952                }
2953            }
2954            TopK {
2955                input: _,
2956                group_key: _,
2957                order_key: _,
2958                limit,
2959                offset,
2960                expected_group_size: _,
2961            } => {
2962                if let Some(limit) = limit {
2963                    f(limit)?
2964                }
2965                f(offset)?
2966            }
2967            Distinct { input: _ }
2968            | Negate { input: _ }
2969            | Threshold { input: _ }
2970            | Union { base: _, inputs: _ } => (),
2971        }
2972        Ok(())
2973    }
2974}
2975
2976impl HirScalarExpr {
2977    pub fn name(&self) -> Option<Arc<str>> {
2978        use HirScalarExpr::*;
2979        match self {
2980            Column(_, name)
2981            | Parameter(_, name)
2982            | Literal(_, _, name)
2983            | CallUnmaterializable(_, name)
2984            | CallUnary { name, .. }
2985            | CallBinary { name, .. }
2986            | CallVariadic { name, .. }
2987            | If { name, .. }
2988            | Exists(_, name)
2989            | Select(_, name)
2990            | Windowing(_, name) => name.0.clone(),
2991        }
2992    }
2993
2994    /// Replaces any parameter references in the expression with the
2995    /// corresponding datum in `params`.
2996    pub fn bind_parameters(
2997        &mut self,
2998        scx: &StatementContext,
2999        lifetime: QueryLifetime,
3000        params: &Params,
3001    ) -> Result<(), PlanError> {
3002        #[allow(deprecated)]
3003        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3004            if let HirScalarExpr::Parameter(n, name) = e {
3005                let datum = match params.datums.iter().nth(*n - 1) {
3006                    None => return Err(PlanError::UnknownParameter(*n)),
3007                    Some(datum) => datum,
3008                };
3009                let scalar_type = &params.execute_types[*n - 1];
3010                let row = Row::pack([datum]);
3011                let column_type = scalar_type.clone().nullable(datum.is_null());
3012
3013                let name = if let Some(name) = &name.0 {
3014                    Some(Arc::clone(name))
3015                } else {
3016                    Some(Arc::from(format!("${n}")))
3017                };
3018
3019                let qcx = QueryContext::root(scx, lifetime);
3020                let ecx = execute_expr_context(&qcx);
3021
3022                *e = plan_cast(
3023                    &ecx,
3024                    *EXECUTE_CAST_CONTEXT,
3025                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3026                    &params.expected_types[*n - 1],
3027                )
3028                .expect("checked in plan_params");
3029            }
3030            Ok(())
3031        })
3032    }
3033
3034    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3035    /// replaced with the corresponding expression fragment from `params` rather
3036    /// than a datum.
3037    ///
3038    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3039    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3040    /// in `self` that refer to invalid indices of `params` will cause a panic.
3041    ///
3042    /// Column references in parameters will be corrected to account for the
3043    /// depth at which they are spliced.
3044    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3045        #[allow(deprecated)]
3046        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3047                                                        e: &mut HirScalarExpr|
3048         -> Result<(), ()> {
3049            if let HirScalarExpr::Parameter(i, _name) = e {
3050                *e = params[*i - 1].clone();
3051                // Correct any column references in the parameter expression for
3052                // its new depth.
3053                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3054                    if col.level >= d {
3055                        col.level += depth
3056                    }
3057                });
3058            }
3059            Ok(())
3060        });
3061    }
3062
3063    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3064    pub fn contains_temporal(&self) -> bool {
3065        let mut contains = false;
3066        #[allow(deprecated)]
3067        self.visit_post_nolimit(&mut |e| {
3068            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3069                contains = true;
3070            }
3071        });
3072        contains
3073    }
3074
3075    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3076    pub fn contains_unmaterializable(&self) -> bool {
3077        let mut contains = false;
3078        #[allow(deprecated)]
3079        self.visit_post_nolimit(&mut |e| {
3080            if let Self::CallUnmaterializable(_, _) = e {
3081                contains = true;
3082            }
3083        });
3084        contains
3085    }
3086
3087    /// Constructs an unnamed column reference in the current scope.
3088    /// Use [`HirScalarExpr::named_column`] when a name is known.
3089    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3090    pub fn column(index: usize) -> HirScalarExpr {
3091        HirScalarExpr::Column(
3092            ColumnRef {
3093                level: 0,
3094                column: index,
3095            },
3096            TreatAsEqual(None),
3097        )
3098    }
3099
3100    /// Constructs an unnamed column reference.
3101    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3102        HirScalarExpr::Column(cr, TreatAsEqual(None))
3103    }
3104
3105    /// Constructs a named column reference.
3106    /// Names are interned by a `NameManager`.
3107    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3108        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3109    }
3110
3111    pub fn parameter(n: usize) -> HirScalarExpr {
3112        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3113    }
3114
3115    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3116        let col_type = scalar_type.nullable(datum.is_null());
3117        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3118        let row = Row::pack([datum]);
3119        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3120    }
3121
3122    pub fn literal_true() -> HirScalarExpr {
3123        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3124    }
3125
3126    pub fn literal_false() -> HirScalarExpr {
3127        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3128    }
3129
3130    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3131        HirScalarExpr::literal(Datum::Null, scalar_type)
3132    }
3133
3134    pub fn literal_1d_array(
3135        datums: Vec<Datum>,
3136        element_scalar_type: SqlScalarType,
3137    ) -> Result<HirScalarExpr, PlanError> {
3138        let scalar_type = match element_scalar_type {
3139            SqlScalarType::Array(_) => {
3140                sql_bail!("cannot build array from array type");
3141            }
3142            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3143        };
3144
3145        let mut row = Row::default();
3146        row.packer()
3147            .try_push_array(
3148                &[ArrayDimension {
3149                    lower_bound: 1,
3150                    length: datums.len(),
3151                }],
3152                datums,
3153            )
3154            .expect("array constructed to be valid");
3155
3156        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3157    }
3158
3159    pub fn as_literal(&self) -> Option<Datum<'_>> {
3160        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3161            Some(row.unpack_first())
3162        } else {
3163            None
3164        }
3165    }
3166
3167    pub fn is_literal_true(&self) -> bool {
3168        Some(Datum::True) == self.as_literal()
3169    }
3170
3171    pub fn is_literal_false(&self) -> bool {
3172        Some(Datum::False) == self.as_literal()
3173    }
3174
3175    pub fn is_literal_null(&self) -> bool {
3176        Some(Datum::Null) == self.as_literal()
3177    }
3178
3179    /// Return true iff `self` consists only of literals, materializable function calls, and
3180    /// if-else statements.
3181    pub fn is_constant(&self) -> bool {
3182        let mut worklist = vec![self];
3183        while let Some(expr) = worklist.pop() {
3184            match expr {
3185                Self::Literal(..) => {
3186                    // leaf node, do nothing
3187                }
3188                Self::CallUnary { expr, .. } => {
3189                    worklist.push(expr);
3190                }
3191                Self::CallBinary {
3192                    func: _,
3193                    expr1,
3194                    expr2,
3195                    name: _,
3196                } => {
3197                    worklist.push(expr1);
3198                    worklist.push(expr2);
3199                }
3200                Self::CallVariadic {
3201                    func: _,
3202                    exprs,
3203                    name: _,
3204                } => {
3205                    worklist.extend(exprs.iter());
3206                }
3207                // (CallUnmaterializable is not allowed)
3208                Self::If {
3209                    cond,
3210                    then,
3211                    els,
3212                    name: _,
3213                } => {
3214                    worklist.push(cond);
3215                    worklist.push(then);
3216                    worklist.push(els);
3217                }
3218                _ => {
3219                    return false; // Any other node makes `self` non-constant.
3220                }
3221            }
3222        }
3223        true
3224    }
3225
3226    pub fn call_unary(self, func: UnaryFunc) -> Self {
3227        HirScalarExpr::CallUnary {
3228            func,
3229            expr: Box::new(self),
3230            name: NameMetadata::default(),
3231        }
3232    }
3233
3234    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3235        HirScalarExpr::CallBinary {
3236            func: func.into(),
3237            expr1: Box::new(self),
3238            expr2: Box::new(other),
3239            name: NameMetadata::default(),
3240        }
3241    }
3242
3243    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3244        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3245    }
3246
3247    pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3248        HirScalarExpr::CallVariadic {
3249            func,
3250            exprs,
3251            name: NameMetadata::default(),
3252        }
3253    }
3254
3255    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3256        HirScalarExpr::If {
3257            cond: Box::new(cond),
3258            then: Box::new(then),
3259            els: Box::new(els),
3260            name: NameMetadata::default(),
3261        }
3262    }
3263
3264    pub fn windowing(expr: WindowExpr) -> Self {
3265        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3266    }
3267
3268    pub fn or(self, other: Self) -> Self {
3269        HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3270    }
3271
3272    pub fn and(self, other: Self) -> Self {
3273        HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3274    }
3275
3276    pub fn not(self) -> Self {
3277        self.call_unary(UnaryFunc::Not(func::Not))
3278    }
3279
3280    pub fn call_is_null(self) -> Self {
3281        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3282    }
3283
3284    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3285    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3286        match args.len() {
3287            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3288            1 => args.swap_remove(0),
3289            _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3290        }
3291    }
3292
3293    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3294    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3295        match args.len() {
3296            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3297            1 => args.swap_remove(0),
3298            _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3299        }
3300    }
3301
3302    pub fn take(&mut self) -> Self {
3303        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3304    }
3305
3306    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3307    /// Visits the column references in this scalar expression.
3308    ///
3309    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3310    /// which will be incremented with each subquery entered and presented to the supplied
3311    /// function `f`.
3312    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3313    where
3314        F: FnMut(usize, &ColumnRef),
3315    {
3316        #[allow(deprecated)]
3317        let _ = self.visit_recursively(depth, &mut |depth: usize,
3318                                                    e: &HirScalarExpr|
3319         -> Result<(), ()> {
3320            if let HirScalarExpr::Column(col, _name) = e {
3321                f(depth, col)
3322            }
3323            Ok(())
3324        });
3325    }
3326
3327    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3328    /// Like `visit_columns`, but permits mutating the column references.
3329    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3330    where
3331        F: FnMut(usize, &mut ColumnRef),
3332    {
3333        #[allow(deprecated)]
3334        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3335                                                        e: &mut HirScalarExpr|
3336         -> Result<(), ()> {
3337            if let HirScalarExpr::Column(col, _name) = e {
3338                f(depth, col)
3339            }
3340            Ok(())
3341        });
3342    }
3343
3344    /// Visits those column references in this scalar expression that refer to the root
3345    /// level. These include column references that are at the root level, as well as column
3346    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3347    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3348    /// "root level" to be `self`'s level.)
3349    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3350    where
3351        F: FnMut(usize),
3352    {
3353        #[allow(deprecated)]
3354        let _ = self.visit_recursively(0, &mut |depth: usize,
3355                                                e: &HirScalarExpr|
3356         -> Result<(), ()> {
3357            if let HirScalarExpr::Column(col, _name) = e {
3358                if col.level == depth {
3359                    f(col.column)
3360                }
3361            }
3362            Ok(())
3363        });
3364    }
3365
3366    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3367    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3368    where
3369        F: FnMut(&mut usize),
3370    {
3371        #[allow(deprecated)]
3372        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3373                                                    e: &mut HirScalarExpr|
3374         -> Result<(), ()> {
3375            if let HirScalarExpr::Column(col, _name) = e {
3376                if col.level == depth {
3377                    f(&mut col.column)
3378                }
3379            }
3380            Ok(())
3381        });
3382    }
3383
3384    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3385    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3386    /// in them. It takes the current depth of the expression and increases it when
3387    /// entering a subquery.
3388    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3389    where
3390        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3391    {
3392        match self {
3393            HirScalarExpr::Literal(..)
3394            | HirScalarExpr::Parameter(..)
3395            | HirScalarExpr::CallUnmaterializable(..)
3396            | HirScalarExpr::Column(..) => (),
3397            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3398            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3399                expr1.visit_recursively(depth, f)?;
3400                expr2.visit_recursively(depth, f)?;
3401            }
3402            HirScalarExpr::CallVariadic { exprs, .. } => {
3403                for expr in exprs {
3404                    expr.visit_recursively(depth, f)?;
3405                }
3406            }
3407            HirScalarExpr::If {
3408                cond,
3409                then,
3410                els,
3411                name: _,
3412            } => {
3413                cond.visit_recursively(depth, f)?;
3414                then.visit_recursively(depth, f)?;
3415                els.visit_recursively(depth, f)?;
3416            }
3417            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3418                #[allow(deprecated)]
3419                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3420                    e.visit_recursively(depth, f)
3421                })?;
3422            }
3423            HirScalarExpr::Windowing(expr, _name) => {
3424                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3425            }
3426        }
3427        f(depth, self)
3428    }
3429
3430    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3431    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3432    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3433    where
3434        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3435    {
3436        match self {
3437            HirScalarExpr::Literal(..)
3438            | HirScalarExpr::Parameter(..)
3439            | HirScalarExpr::CallUnmaterializable(..)
3440            | HirScalarExpr::Column(..) => (),
3441            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3442            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3443                expr1.visit_recursively_mut(depth, f)?;
3444                expr2.visit_recursively_mut(depth, f)?;
3445            }
3446            HirScalarExpr::CallVariadic { exprs, .. } => {
3447                for expr in exprs {
3448                    expr.visit_recursively_mut(depth, f)?;
3449                }
3450            }
3451            HirScalarExpr::If {
3452                cond,
3453                then,
3454                els,
3455                name: _,
3456            } => {
3457                cond.visit_recursively_mut(depth, f)?;
3458                then.visit_recursively_mut(depth, f)?;
3459                els.visit_recursively_mut(depth, f)?;
3460            }
3461            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3462                #[allow(deprecated)]
3463                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3464                    e.visit_recursively_mut(depth, f)
3465                })?;
3466            }
3467            HirScalarExpr::Windowing(expr, _name) => {
3468                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3469            }
3470        }
3471        f(depth, self)
3472    }
3473
3474    /// Attempts to simplify self into a literal.
3475    ///
3476    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3477    /// an evaluation error occurs during simplification, or if self contains
3478    /// - a subquery
3479    /// - a column reference to an outer level
3480    /// - a parameter
3481    /// - a window function call
3482    fn simplify_to_literal(self) -> Option<Row> {
3483        let mut expr = self.lower_uncorrelated().ok()?;
3484        expr.reduce(&[]);
3485        match expr {
3486            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3487            _ => None,
3488        }
3489    }
3490
3491    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3492    /// or an evaluation error occurs during simplification), it returns
3493    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3494    ///
3495    /// The returned error is an _internal_ error if the expression contains
3496    /// - a subquery
3497    /// - a column reference to an outer level
3498    /// - a parameter
3499    /// - a window function call
3500    ///
3501    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3502    /// msg.
3503    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3504        let mut expr = self.lower_uncorrelated().map_err(|err| {
3505            PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3506        })?;
3507        expr.reduce(&[]);
3508        match expr {
3509            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3510            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3511                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3512            ),
3513            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3514                "Not a constant".to_string(),
3515            )),
3516        }
3517    }
3518
3519    /// Attempts to simplify this expression to a literal 64-bit integer.
3520    ///
3521    /// Returns `None` if this expression cannot be simplified, e.g. because it
3522    /// contains non-literal values.
3523    ///
3524    /// # Panics
3525    ///
3526    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3527    pub fn into_literal_int64(self) -> Option<i64> {
3528        self.simplify_to_literal().and_then(|row| {
3529            let datum = row.unpack_first();
3530            if datum.is_null() {
3531                None
3532            } else {
3533                Some(datum.unwrap_int64())
3534            }
3535        })
3536    }
3537
3538    /// Attempts to simplify this expression to a literal string.
3539    ///
3540    /// Returns `None` if this expression cannot be simplified, e.g. because it
3541    /// contains non-literal values.
3542    ///
3543    /// # Panics
3544    ///
3545    /// Panics if this expression does not have type [`SqlScalarType::String`].
3546    pub fn into_literal_string(self) -> Option<String> {
3547        self.simplify_to_literal().and_then(|row| {
3548            let datum = row.unpack_first();
3549            if datum.is_null() {
3550                None
3551            } else {
3552                Some(datum.unwrap_str().to_owned())
3553            }
3554        })
3555    }
3556
3557    /// Attempts to simplify this expression to a literal MzTimestamp.
3558    ///
3559    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3560    /// simplified, e.g. because it contains non-literal values or a cast fails.
3561    ///
3562    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3563    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3564    /// See `try_into_literal_int64` as an example.
3565    ///
3566    /// # Panics
3567    ///
3568    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3569    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3570        self.simplify_to_literal().and_then(|row| {
3571            let datum = row.unpack_first();
3572            if datum.is_null() {
3573                None
3574            } else {
3575                Some(datum.unwrap_mz_timestamp())
3576            }
3577        })
3578    }
3579
3580    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3581    /// returns it as an i64.
3582    ///
3583    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3584    /// - it's not a constant expression (as determined by `is_constant`)
3585    /// - evaluates to null
3586    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3587    ///
3588    /// # Panics
3589    ///
3590    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3591    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3592        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3593        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3594        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3595        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3596        // preconditions also of all the other into_literal_... functions.)
3597        if !self.is_constant() {
3598            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3599                "Expected a constant expression, got {}",
3600                self
3601            )));
3602        }
3603        self.clone()
3604            .simplify_to_literal_with_result()
3605            .and_then(|row| {
3606                let datum = row.unpack_first();
3607                if datum.is_null() {
3608                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3609                        "Expected an expression that evaluates to a non-null value, got {}",
3610                        self
3611                    )))
3612                } else {
3613                    Ok(datum.unwrap_int64())
3614                }
3615            })
3616    }
3617
3618    pub fn contains_parameters(&self) -> bool {
3619        let mut contains_parameters = false;
3620        #[allow(deprecated)]
3621        let _ = self.visit_recursively(0, &mut |_depth: usize,
3622                                                expr: &HirScalarExpr|
3623         -> Result<(), ()> {
3624            if let HirScalarExpr::Parameter(..) = expr {
3625                contains_parameters = true;
3626            }
3627            Ok(())
3628        });
3629        contains_parameters
3630    }
3631}
3632
3633impl VisitChildren<Self> for HirScalarExpr {
3634    fn visit_children<F>(&self, mut f: F)
3635    where
3636        F: FnMut(&Self),
3637    {
3638        use HirScalarExpr::*;
3639        match self {
3640            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3641            CallUnary { expr, .. } => f(expr),
3642            CallBinary { expr1, expr2, .. } => {
3643                f(expr1);
3644                f(expr2);
3645            }
3646            CallVariadic { exprs, .. } => {
3647                for expr in exprs {
3648                    f(expr);
3649                }
3650            }
3651            If {
3652                cond,
3653                then,
3654                els,
3655                name: _,
3656            } => {
3657                f(cond);
3658                f(then);
3659                f(els);
3660            }
3661            Exists(..) | Select(..) => (),
3662            Windowing(expr, _name) => expr.visit_children(f),
3663        }
3664    }
3665
3666    fn visit_mut_children<F>(&mut self, mut f: F)
3667    where
3668        F: FnMut(&mut Self),
3669    {
3670        use HirScalarExpr::*;
3671        match self {
3672            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3673            CallUnary { expr, .. } => f(expr),
3674            CallBinary { expr1, expr2, .. } => {
3675                f(expr1);
3676                f(expr2);
3677            }
3678            CallVariadic { exprs, .. } => {
3679                for expr in exprs {
3680                    f(expr);
3681                }
3682            }
3683            If {
3684                cond,
3685                then,
3686                els,
3687                name: _,
3688            } => {
3689                f(cond);
3690                f(then);
3691                f(els);
3692            }
3693            Exists(..) | Select(..) => (),
3694            Windowing(expr, _name) => expr.visit_mut_children(f),
3695        }
3696    }
3697
3698    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3699    where
3700        F: FnMut(&Self) -> Result<(), E>,
3701        E: From<RecursionLimitError>,
3702    {
3703        use HirScalarExpr::*;
3704        match self {
3705            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3706            CallUnary { expr, .. } => f(expr)?,
3707            CallBinary { expr1, expr2, .. } => {
3708                f(expr1)?;
3709                f(expr2)?;
3710            }
3711            CallVariadic { exprs, .. } => {
3712                for expr in exprs {
3713                    f(expr)?;
3714                }
3715            }
3716            If {
3717                cond,
3718                then,
3719                els,
3720                name: _,
3721            } => {
3722                f(cond)?;
3723                f(then)?;
3724                f(els)?;
3725            }
3726            Exists(..) | Select(..) => (),
3727            Windowing(expr, _name) => expr.try_visit_children(f)?,
3728        }
3729        Ok(())
3730    }
3731
3732    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3733    where
3734        F: FnMut(&mut Self) -> Result<(), E>,
3735        E: From<RecursionLimitError>,
3736    {
3737        use HirScalarExpr::*;
3738        match self {
3739            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3740            CallUnary { expr, .. } => f(expr)?,
3741            CallBinary { expr1, expr2, .. } => {
3742                f(expr1)?;
3743                f(expr2)?;
3744            }
3745            CallVariadic { exprs, .. } => {
3746                for expr in exprs {
3747                    f(expr)?;
3748                }
3749            }
3750            If {
3751                cond,
3752                then,
3753                els,
3754                name: _,
3755            } => {
3756                f(cond)?;
3757                f(then)?;
3758                f(els)?;
3759            }
3760            Exists(..) | Select(..) => (),
3761            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3762        }
3763        Ok(())
3764    }
3765}
3766
3767impl AbstractExpr for HirScalarExpr {
3768    type Type = SqlColumnType;
3769
3770    fn typ(
3771        &self,
3772        outers: &[SqlRelationType],
3773        inner: &SqlRelationType,
3774        params: &BTreeMap<usize, SqlScalarType>,
3775    ) -> Self::Type {
3776        stack::maybe_grow(|| match self {
3777            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3778                if *level == 0 {
3779                    inner.column_types[*column].clone()
3780                } else {
3781                    outers[*level - 1].column_types[*column].clone()
3782                }
3783            }
3784            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3785            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3786            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3787            HirScalarExpr::CallUnary {
3788                expr,
3789                func,
3790                name: _,
3791            } => func.output_type(expr.typ(outers, inner, params)),
3792            HirScalarExpr::CallBinary {
3793                expr1,
3794                expr2,
3795                func,
3796                name: _,
3797            } => func.output_type(
3798                expr1.typ(outers, inner, params),
3799                expr2.typ(outers, inner, params),
3800            ),
3801            HirScalarExpr::CallVariadic {
3802                exprs,
3803                func,
3804                name: _,
3805            } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3806            HirScalarExpr::If {
3807                cond: _,
3808                then,
3809                els,
3810                name: _,
3811            } => {
3812                let then_type = then.typ(outers, inner, params);
3813                let else_type = els.typ(outers, inner, params);
3814                then_type.union(&else_type).unwrap()
3815            }
3816            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3817            HirScalarExpr::Select(expr, _name) => {
3818                let mut outers = outers.to_vec();
3819                outers.insert(0, inner.clone());
3820                expr.typ(&outers, params)
3821                    .column_types
3822                    .into_element()
3823                    .nullable(true)
3824            }
3825            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3826        })
3827    }
3828}
3829
3830impl AggregateExpr {
3831    pub fn typ(
3832        &self,
3833        outers: &[SqlRelationType],
3834        inner: &SqlRelationType,
3835        params: &BTreeMap<usize, SqlScalarType>,
3836    ) -> SqlColumnType {
3837        self.func.output_type(self.expr.typ(outers, inner, params))
3838    }
3839
3840    /// Returns whether the expression is COUNT(*) or not.  Note that
3841    /// when we define the count builtin in sql::func, we convert
3842    /// COUNT(*) to COUNT(true), making it indistinguishable from
3843    /// literal COUNT(true), but we prefer to consider this as the
3844    /// former.
3845    ///
3846    /// (MIR has the same `is_count_asterisk`.)
3847    pub fn is_count_asterisk(&self) -> bool {
3848        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3849    }
3850}