Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25use mz_expr::func::variadic::{And, Or};
26pub use mz_expr::{
27    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
28};
29use mz_ore::collections::CollectionExt;
30use mz_ore::error::ErrorExt;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::separated;
33use mz_ore::treat_as_equal::TreatAsEqual;
34use mz_ore::{soft_assert_or_log, stack};
35use mz_repr::adt::array::ArrayDimension;
36use mz_repr::adt::numeric::NumericMaxScale;
37use mz_repr::*;
38use serde::{Deserialize, Serialize};
39
40use crate::plan::error::PlanError;
41use crate::plan::query::{
42    EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context, offset_into_value,
43};
44use crate::plan::typeconv::{self, CastContext, plan_cast};
45use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
46
47use super::plan_utils::GroupSizeHints;
48
49#[allow(missing_debug_implementations)]
50pub struct Hir;
51
52impl IR for Hir {
53    type Relation = HirRelationExpr;
54    type Scalar = HirScalarExpr;
55}
56
57impl AlgExcept for Hir {
58    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
59        if *all {
60            let rhs = rhs.negate();
61            HirRelationExpr::union(lhs, rhs).threshold()
62        } else {
63            let lhs = lhs.distinct();
64            let rhs = rhs.distinct().negate();
65            HirRelationExpr::union(lhs, rhs).threshold()
66        }
67    }
68
69    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
70        let mut result = None;
71
72        use HirRelationExpr::*;
73        if let Threshold { input } = expr {
74            if let Union { base: lhs, inputs } = input.as_ref() {
75                if let [rhs] = &inputs[..] {
76                    if let Negate { input: rhs } = rhs {
77                        match (lhs.as_ref(), rhs.as_ref()) {
78                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
79                                let all = false;
80                                let lhs = lhs.as_ref();
81                                let rhs = rhs.as_ref();
82                                result = Some(Except { all, lhs, rhs })
83                            }
84                            (lhs, rhs) => {
85                                let all = true;
86                                result = Some(Except { all, lhs, rhs })
87                            }
88                        }
89                    }
90                }
91            }
92        }
93
94        result
95    }
96}
97
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
pub enum HirRelationExpr {
    /// A constant collection holding `rows`, whose relation type is `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// Reads the collection identified by `id` (e.g. one bound by `Let` or
    /// `LetRec`); `typ` is its relation type.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        /// NOTE(review): each tuple appears to be `(name, id, value, type)`,
        /// mirroring the fields of `Let` — confirm.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Calls the table function `func` with the argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// User-supplied hint; presumably the same meaning as
        /// `TopK::expected_group_size` (how many rows share a group key) — confirm.
        expected_group_size: Option<u64>,
    },
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Used by `AlgExcept::except`, which builds `EXCEPT` as
    /// `Threshold(Union(lhs, Negate(rhs)))`.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
211
/// Stored column metadata: an optional name carried by every [`HirScalarExpr`]
/// variant.
///
/// NOTE(review): the `TreatAsEqual` wrapper presumably keeps the name out of
/// equality/ordering/hashing of the enclosing expression — confirm against
/// `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
214
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant carries a [`NameMetadata`], an optional name for the expression.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A query parameter such as `$1`; the `usize` identifies which one
    /// (presumably its 1-based position — confirm against the planner).
    Parameter(usize, NameMetadata),
    /// A constant value stored in a `Row`, together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to a function that takes no argument expressions.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a function with one argument expression.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function with two argument expressions.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function with any number of argument expressions.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Evaluates to `then` or `els` depending on `cond`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function invocation; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
266
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked. Each variant carries its own
    /// `ColumnOrder`-based `order_by` (see the comment on `order_by` below).
    pub func: WindowExprType,
    /// The `PARTITION BY` expressions.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` fields in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
295
296impl WindowExpr {
297    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
298    where
299        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
300    {
301        #[allow(deprecated)]
302        self.func.visit_expressions(f)?;
303        for expr in self.partition_by.iter() {
304            f(expr)?;
305        }
306        for expr in self.order_by.iter() {
307            f(expr)?;
308        }
309        Ok(())
310    }
311
312    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
313    where
314        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
315    {
316        #[allow(deprecated)]
317        self.func.visit_expressions_mut(f)?;
318        for expr in self.partition_by.iter_mut() {
319            f(expr)?;
320        }
321        for expr in self.order_by.iter_mut() {
322            f(expr)?;
323        }
324        Ok(())
325    }
326}
327
328/// Yields the scalars in `func`'s arguments, plus those in `partition_by`
329/// and `order_by`; does not descend into them.
330impl VisitChildren<HirScalarExpr> for WindowExpr {
331    fn visit_children<F>(&self, mut f: F)
332    where
333        F: FnMut(&HirScalarExpr),
334    {
335        self.func.visit_children(&mut f);
336        for expr in self.partition_by.iter() {
337            f(expr);
338        }
339        for expr in self.order_by.iter() {
340            f(expr);
341        }
342    }
343
344    fn visit_mut_children<F>(&mut self, mut f: F)
345    where
346        F: FnMut(&mut HirScalarExpr),
347    {
348        self.func.visit_mut_children(&mut f);
349        for expr in self.partition_by.iter_mut() {
350            f(expr);
351        }
352        for expr in self.order_by.iter_mut() {
353            f(expr);
354        }
355    }
356
357    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
358    where
359        F: FnMut(&HirScalarExpr) -> Result<(), E>,
360        E: From<RecursionLimitError>,
361    {
362        self.func.try_visit_children(&mut f)?;
363        for expr in self.partition_by.iter() {
364            f(expr)?;
365        }
366        for expr in self.order_by.iter() {
367            f(expr)?;
368        }
369        Ok(())
370    }
371
372    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
373    where
374        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
375        E: From<RecursionLimitError>,
376    {
377        self.func.try_visit_mut_children(&mut f)?;
378        for expr in self.partition_by.iter_mut() {
379            f(expr)?;
380        }
381        for expr in self.order_by.iter_mut() {
382            f(expr)?;
383        }
384        Ok(())
385    }
386}
387
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function, e.g. `row_number`.
    Scalar(ScalarWindowExpr),
    /// A value window function, e.g. `lag` or `first_value`.
    Value(ValueWindowExpr),
    /// An aggregate used as a window function.
    Aggregate(AggregateWindowExpr),
}
420
421impl WindowExprType {
422    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
423    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
424    where
425        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
426    {
427        #[allow(deprecated)]
428        match self {
429            Self::Scalar(expr) => expr.visit_expressions(f),
430            Self::Value(expr) => expr.visit_expressions(f),
431            Self::Aggregate(expr) => expr.visit_expressions(f),
432        }
433    }
434
435    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
436    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
437    where
438        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
439    {
440        #[allow(deprecated)]
441        match self {
442            Self::Scalar(expr) => expr.visit_expressions_mut(f),
443            Self::Value(expr) => expr.visit_expressions_mut(f),
444            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
445        }
446    }
447
448    fn typ(
449        &self,
450        outers: &[SqlRelationType],
451        inner: &SqlRelationType,
452        params: &BTreeMap<usize, SqlScalarType>,
453    ) -> SqlColumnType {
454        match self {
455            Self::Scalar(expr) => expr.typ(outers, inner, params),
456            Self::Value(expr) => expr.typ(outers, inner, params),
457            Self::Aggregate(expr) => expr.typ(outers, inner, params),
458        }
459    }
460}
461
462/// Dispatches to the inner `Value` / `Aggregate` variant's scalar children;
463/// `Scalar` window functions have no scalar children.
464impl VisitChildren<HirScalarExpr> for WindowExprType {
465    fn visit_children<F>(&self, f: F)
466    where
467        F: FnMut(&HirScalarExpr),
468    {
469        match self {
470            Self::Scalar(_) => (),
471            Self::Value(expr) => expr.visit_children(f),
472            Self::Aggregate(expr) => expr.visit_children(f),
473        }
474    }
475
476    fn visit_mut_children<F>(&mut self, f: F)
477    where
478        F: FnMut(&mut HirScalarExpr),
479    {
480        match self {
481            Self::Scalar(_) => (),
482            Self::Value(expr) => expr.visit_mut_children(f),
483            Self::Aggregate(expr) => expr.visit_mut_children(f),
484        }
485    }
486
487    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
488    where
489        F: FnMut(&HirScalarExpr) -> Result<(), E>,
490        E: From<RecursionLimitError>,
491    {
492        match self {
493            Self::Scalar(_) => Ok(()),
494            Self::Value(expr) => expr.try_visit_children(f),
495            Self::Aggregate(expr) => expr.try_visit_children(f),
496        }
497    }
498
499    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
500    where
501        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
502        E: From<RecursionLimitError>,
503    {
504        match self {
505            Self::Scalar(_) => Ok(()),
506            Self::Value(expr) => expr.try_visit_mut_children(f),
507            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
508        }
509    }
510}
511
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A scalar window function call (`row_number`, `rank` or `dense_rank`).
/// These functions take no argument expressions.
pub struct ScalarWindowExpr {
    /// Which scalar window function is called.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
527
528impl ScalarWindowExpr {
529    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
530    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
531    where
532        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
533    {
534        match self.func {
535            ScalarWindowFunc::RowNumber => {}
536            ScalarWindowFunc::Rank => {}
537            ScalarWindowFunc::DenseRank => {}
538        }
539        Ok(())
540    }
541
542    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
543    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
544    where
545        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
546    {
547        match self.func {
548            ScalarWindowFunc::RowNumber => {}
549            ScalarWindowFunc::Rank => {}
550            ScalarWindowFunc::DenseRank => {}
551        }
552        Ok(())
553    }
554
555    fn typ(
556        &self,
557        _outers: &[SqlRelationType],
558        _inner: &SqlRelationType,
559        _params: &BTreeMap<usize, SqlScalarType>,
560    ) -> SqlColumnType {
561        self.func.output_sql_type()
562    }
563
564    pub fn into_expr(self) -> mz_expr::AggregateFunc {
565        match self.func {
566            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
567                order_by: self.order_by,
568            },
569            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
570                order_by: self.order_by,
571            },
572            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
573                order_by: self.order_by,
574            },
575        }
576    }
577}
578
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Scalar Window functions
///
/// Each of them produces a non-nullable `Int64` (see `output_sql_type`).
pub enum ScalarWindowFunc {
    /// `row_number`
    RowNumber,
    /// `rank`
    Rank,
    /// `dense_rank`
    DenseRank,
}
596
597impl Display for ScalarWindowFunc {
598    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
599        match self {
600            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
601            ScalarWindowFunc::Rank => write!(f, "rank"),
602            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
603        }
604    }
605}
606
607impl ScalarWindowFunc {
608    pub fn output_sql_type(&self) -> SqlColumnType {
609        match self {
610            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
611            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
612            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
613        }
614    }
615}
616
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A value window function call, such as `lag` or `first_value`.
pub struct ValueWindowExpr {
    /// Which value window function is called.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. `ValueWindowFunc::into_expr` passes it on to
    /// `first_value` / `last_value`; `lag` / `lead` do not use it.
    pub window_frame: WindowFrame,
    /// Only consumed by `lag` / `lead` (see `ValueWindowFunc::into_expr`).
    pub ignore_nulls: bool,
}
641
642impl Display for ValueWindowFunc {
643    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
644        match self {
645            ValueWindowFunc::Lag => write!(f, "lag"),
646            ValueWindowFunc::Lead => write!(f, "lead"),
647            ValueWindowFunc::FirstValue => write!(f, "first_value"),
648            ValueWindowFunc::LastValue => write!(f, "last_value"),
649            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
650        }
651    }
652}
653
654impl ValueWindowExpr {
655    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
656    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
657    where
658        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
659    {
660        f(&self.args)
661    }
662
663    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
664    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
665    where
666        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
667    {
668        f(&mut self.args)
669    }
670
671    fn typ(
672        &self,
673        outers: &[SqlRelationType],
674        inner: &SqlRelationType,
675        params: &BTreeMap<usize, SqlScalarType>,
676    ) -> SqlColumnType {
677        self.func
678            .output_sql_type(self.args.typ(outers, inner, params))
679    }
680
681    /// Converts into `mz_expr::AggregateFunc`.
682    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
683        (
684            self.args,
685            self.func
686                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
687        )
688    }
689}
690
691/// Yields `args` (the value-window function's argument expression);
692/// does not descend into it.
693impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
694    fn visit_children<F>(&self, mut f: F)
695    where
696        F: FnMut(&HirScalarExpr),
697    {
698        f(&self.args)
699    }
700
701    fn visit_mut_children<F>(&mut self, mut f: F)
702    where
703        F: FnMut(&mut HirScalarExpr),
704    {
705        f(&mut self.args)
706    }
707
708    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
709    where
710        F: FnMut(&HirScalarExpr) -> Result<(), E>,
711        E: From<RecursionLimitError>,
712    {
713        f(&self.args)
714    }
715
716    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
717    where
718        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
719        E: From<RecursionLimitError>,
720    {
721        f(&mut self.args)
722    }
723}
724
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// `lag`; its argument is a `(value, offset, default)` record.
    Lag,
    /// `lead`; its argument is a `(value, offset, default)` record.
    Lead,
    /// `first_value`
    FirstValue,
    /// `last_value`
    LastValue,
    /// Several value window functions computed together. The argument is a
    /// record holding each constituent's argument, and the output is a record
    /// holding each constituent's result (see `output_sql_type`).
    Fused(Vec<ValueWindowFunc>),
}
744
745impl ValueWindowFunc {
746    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
747        match self {
748            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
749                // The input is a (value, offset, default) record, so extract the type of the first arg
750                input_type.scalar_type.unwrap_record_element_type()[0]
751                    .clone()
752                    .nullable(true)
753            }
754            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
755                input_type.scalar_type.nullable(true)
756            }
757            ValueWindowFunc::Fused(funcs) => {
758                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
759                SqlScalarType::Record {
760                    fields: funcs
761                        .iter()
762                        .zip_eq(input_types)
763                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
764                        .collect(),
765                    custom_id: None,
766                }
767                .nullable(false)
768            }
769        }
770    }
771
772    pub fn into_expr(
773        self,
774        order_by: Vec<ColumnOrder>,
775        window_frame: WindowFrame,
776        ignore_nulls: bool,
777    ) -> mz_expr::AggregateFunc {
778        match self {
779            // Lag and Lead are fundamentally the same function, just with opposite directions
780            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
781                order_by,
782                lag_lead: mz_expr::LagLeadType::Lag,
783                ignore_nulls,
784            },
785            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
786                order_by,
787                lag_lead: mz_expr::LagLeadType::Lead,
788                ignore_nulls,
789            },
790            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
791                order_by,
792                window_frame,
793            },
794            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
795                order_by,
796                window_frame,
797            },
798            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
799                funcs: funcs
800                    .into_iter()
801                    .map(|func| {
802                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
803                    })
804                    .collect(),
805                order_by,
806            },
807        }
808    }
809}
810
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// An aggregate used as a window function, e.g. `sum(x) OVER (...)`.
pub struct AggregateWindowExpr {
    /// The aggregate function together with its argument expression
    /// (`aggregate_expr.expr`).
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame, passed on to the lowered aggregate by `into_expr`.
    pub window_frame: WindowFrame,
}
827
828impl AggregateWindowExpr {
829    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
830    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
831    where
832        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
833    {
834        f(&self.aggregate_expr.expr)
835    }
836
837    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
838    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
839    where
840        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
841    {
842        f(&mut self.aggregate_expr.expr)
843    }
844
845    fn typ(
846        &self,
847        outers: &[SqlRelationType],
848        inner: &SqlRelationType,
849        params: &BTreeMap<usize, SqlScalarType>,
850    ) -> SqlColumnType {
851        self.aggregate_expr
852            .func
853            .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params))
854    }
855
856    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
857        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
858            (
859                self.aggregate_expr.expr,
860                FusedWindowAggregate {
861                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
862                    order_by: self.order_by,
863                    window_frame: self.window_frame,
864                },
865            )
866        } else {
867            (
868                self.aggregate_expr.expr,
869                WindowAggregate {
870                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
871                    order_by: self.order_by,
872                    window_frame: self.window_frame,
873                },
874            )
875        }
876    }
877}
878
879/// Yields the aggregate's argument expression (`aggregate_expr.expr`);
880/// does not descend into it.
881impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
882    fn visit_children<F>(&self, mut f: F)
883    where
884        F: FnMut(&HirScalarExpr),
885    {
886        f(&self.aggregate_expr.expr)
887    }
888
889    fn visit_mut_children<F>(&mut self, mut f: F)
890    where
891        F: FnMut(&mut HirScalarExpr),
892    {
893        f(&mut self.aggregate_expr.expr)
894    }
895
896    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
897    where
898        F: FnMut(&HirScalarExpr) -> Result<(), E>,
899        E: From<RecursionLimitError>,
900    {
901        f(&self.aggregate_expr.expr)
902    }
903
904    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
905    where
906        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
907        E: From<RecursionLimitError>,
908    {
909        f(&mut self.aggregate_expr.expr)
910    }
911}
912
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that has already been coerced to a concrete type.
    Coerced(HirScalarExpr),
    /// A parameter such as `$1` whose type is still unconstrained.
    Parameter(usize),
    /// The literal `NULL`.
    LiteralNull,
    /// A string literal such as `'42'`.
    LiteralString(String),
    /// A record literal whose fields may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
945
946impl CoercibleScalarExpr {
947    pub fn type_as(
948        self,
949        ecx: &ExprContext,
950        ty: &SqlScalarType,
951    ) -> Result<HirScalarExpr, PlanError> {
952        let expr = typeconv::plan_coerce(ecx, self, ty)?;
953        let expr_ty = ecx.scalar_type(&expr);
954        if ty != &expr_ty {
955            sql_bail!(
956                "{} must have type {}, not type {}",
957                ecx.name,
958                ecx.humanize_sql_scalar_type(ty, false),
959                ecx.humanize_sql_scalar_type(&expr_ty, false),
960            );
961        }
962        Ok(expr)
963    }
964
965    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
966        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
967    }
968
969    pub fn cast_to(
970        self,
971        ecx: &ExprContext,
972        ccx: CastContext,
973        ty: &SqlScalarType,
974    ) -> Result<HirScalarExpr, PlanError> {
975        let expr = typeconv::plan_coerce(ecx, self, ty)?;
976        typeconv::plan_cast(ecx, ccx, expr, ty)
977    }
978}
979
/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The column type of a [`CoercibleScalarExpr::Coerced`] expression.
    Coerced(SqlColumnType),
    /// The per-field column types of a [`CoercibleScalarExpr::LiteralRecord`].
    Record(Vec<CoercibleColumnType>),
    /// The not-yet-determined type of a parameter, `NULL` literal, or string
    /// literal.
    Uncoerced,
}
987
988impl CoercibleColumnType {
989    /// Reports the nullability of the type.
990    pub fn nullable(&self) -> bool {
991        match self {
992            // A coerced value's nullability is known.
993            CoercibleColumnType::Coerced(ct) => ct.nullable,
994
995            // A literal record can never be null.
996            CoercibleColumnType::Record(_) => false,
997
998            // An uncoerced literal may be the literal `NULL`, so we have
999            // to conservatively assume it is nullable.
1000            CoercibleColumnType::Uncoerced => true,
1001        }
1002    }
1003}
1004
/// The scalar type for a [`CoercibleScalarExpr`].
///
/// This is the scalar counterpart of [`CoercibleColumnType`]; the
/// `AbstractColumnType::scalar_type` implementation for `CoercibleColumnType`
/// maps each variant onto the same-named variant here.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A fully determined scalar type.
    Coerced(SqlScalarType),
    /// The per-field column types of a literal record.
    Record(Vec<CoercibleColumnType>),
    /// A scalar type that has not been determined yet.
    Uncoerced,
}
1012
1013impl CoercibleScalarType {
1014    /// Reports whether the scalar type has been coerced.
1015    pub fn is_coerced(&self) -> bool {
1016        matches!(self, CoercibleScalarType::Coerced(_))
1017    }
1018
1019    /// Returns the coerced scalar type, if the type is coerced.
1020    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1021        match self {
1022            CoercibleScalarType::Coerced(t) => Some(t),
1023            _ => None,
1024        }
1025    }
1026
1027    /// If the type is coerced, apply the mapping function to the contained
1028    /// scalar type.
1029    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1030    where
1031        F: FnOnce(SqlScalarType) -> SqlScalarType,
1032    {
1033        match self {
1034            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1035            _ => self,
1036        }
1037    }
1038
1039    /// If the type is an coercible record, forcibly converts to a coerced
1040    /// record type. Any uncoerced field types are assumed to be of type text.
1041    ///
1042    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
1043    /// accepts a type hint that can indicate the types of uncoerced field
1044    /// types.
1045    pub fn force_coerced_if_record(&mut self) {
1046        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1047            let mut fields = vec![];
1048            for (i, uf) in uncoerced_fields.enumerate() {
1049                let name = ColumnName::from(format!("f{}", i + 1));
1050                let ty = match uf {
1051                    CoercibleColumnType::Coerced(ty) => ty,
1052                    CoercibleColumnType::Record(mut fields) => {
1053                        convert(fields.drain(..)).nullable(false)
1054                    }
1055                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1056                };
1057                fields.push((name, ty))
1058            }
1059            SqlScalarType::Record {
1060                fields: fields.into(),
1061                custom_id: None,
1062            }
1063        }
1064
1065        if let CoercibleScalarType::Record(fields) = self {
1066            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1067        }
1068    }
1069}
1070
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column-type-like value produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` gives the types of the enclosing scopes, `inner` gives the
    /// type of the relation the expression is evaluated against, and `params`
    /// gives the types of query parameters.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
1085
1086impl AbstractExpr for CoercibleScalarExpr {
1087    type Type = CoercibleColumnType;
1088
1089    fn typ(
1090        &self,
1091        outers: &[SqlRelationType],
1092        inner: &SqlRelationType,
1093        params: &BTreeMap<usize, SqlScalarType>,
1094    ) -> Self::Type {
1095        match self {
1096            CoercibleScalarExpr::Coerced(expr) => {
1097                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1098            }
1099            CoercibleScalarExpr::LiteralRecord(scalars) => {
1100                let fields = scalars
1101                    .iter()
1102                    .map(|s| s.typ(outers, inner, params))
1103                    .collect();
1104                CoercibleColumnType::Record(fields)
1105            }
1106            _ => CoercibleColumnType::Uncoerced,
1107        }
1108    }
1109}
1110
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar-type-like value returned by
    /// [`AbstractColumnType::scalar_type`].
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1122
1123impl AbstractColumnType for SqlColumnType {
1124    type AbstractScalarType = SqlScalarType;
1125
1126    fn scalar_type(self) -> Self::AbstractScalarType {
1127        self.scalar_type
1128    }
1129}
1130
1131impl AbstractColumnType for CoercibleColumnType {
1132    type AbstractScalarType = CoercibleScalarType;
1133
1134    fn scalar_type(self) -> Self::AbstractScalarType {
1135        match self {
1136            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1137            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1138            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1139        }
1140    }
1141}
1142
1143impl From<HirScalarExpr> for CoercibleScalarExpr {
1144    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1145        CoercibleScalarExpr::Coerced(expr)
1146    }
1147}
1148
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize
)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier within the scope selected by `level`
    /// (presumably the column's position in that scope's relation type —
    /// confirm).
    pub column: usize,
}
1181
/// The kind of join performed by [`HirRelationExpr::Join`].
///
/// For the outer kinds, `HirRelationExpr::typ` reports the columns of the
/// non-preserved side(s) as nullable.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum JoinKind {
    /// An inner join: the join itself makes neither side's columns nullable.
    Inner,
    /// A left outer join: the right input's columns become nullable.
    LeftOuter,
    /// A right outer join: the left input's columns become nullable.
    RightOuter,
    /// A full outer join: both inputs' columns become nullable.
    FullOuter,
}
1199
1200impl fmt::Display for JoinKind {
1201    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1202        write!(
1203            f,
1204            "{}",
1205            match self {
1206                JoinKind::Inner => "Inner",
1207                JoinKind::LeftOuter => "LeftOuter",
1208                JoinKind::RightOuter => "RightOuter",
1209                JoinKind::FullOuter => "FullOuter",
1210            }
1211        )
1212    }
1213}
1214
1215impl JoinKind {
1216    pub fn can_be_correlated(&self) -> bool {
1217        match self {
1218            JoinKind::Inner | JoinKind::LeftOuter => true,
1219            JoinKind::RightOuter | JoinKind::FullOuter => false,
1220        }
1221    }
1222
1223    pub fn can_elide_identity_left_join(&self) -> bool {
1224        match self {
1225            JoinKind::Inner | JoinKind::RightOuter => true,
1226            JoinKind::LeftOuter | JoinKind::FullOuter => false,
1227        }
1228    }
1229
1230    pub fn can_elide_identity_right_join(&self) -> bool {
1231        match self {
1232            JoinKind::Inner | JoinKind::LeftOuter => true,
1233            JoinKind::RightOuter | JoinKind::FullOuter => false,
1234        }
1235    }
1236}
1237
/// An aggregate function applied to an input scalar expression.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are aggregated.
    pub expr: Box<HirScalarExpr>,
    /// Whether only distinct input values are aggregated (presumably SQL
    /// `agg(DISTINCT ...)` — confirm against the planner).
    pub distinct: bool,
}
1254
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, the nullability of the aggregate columns is more common
/// here than in `expr`, as these aggregates may be applied over empty
/// result sets and should be null in those cases, whereas `expr` variants
/// only return null values when supplied nulls as input.
///
/// Note: the derived `Ord` depends on variant declaration order.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    Count,
    Any,
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Aggregates its inputs into a single string value (its output type is
    /// `SqlScalarType::String`). Presumably, as for the variants above, the
    /// inputs carry extra columns used by `order_by` — confirm.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, each of
    /// whose components is the input to the corresponding `AggregateFunc` in
    /// `funcs`.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1368
1369impl AggregateFunc {
1370    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1371    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1372        match self {
1373            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1374            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1375            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1376            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1377            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1378            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1379            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1380            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1381            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1382            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1383            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1384            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1385            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1386            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1387            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1388            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1389            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1390            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1391            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1392            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1393            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1394            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1395            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1396            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1397            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1398            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1399            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1400            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1401            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1402            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1403            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1404            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1405            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1406            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1407            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1408            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1409            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1410            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1411            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1412            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1413            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1414            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1415            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1416            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1417            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1418            AggregateFunc::All => mz_expr::AggregateFunc::All,
1419            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1420            AggregateFunc::JsonbObjectAgg { order_by } => {
1421                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1422            }
1423            AggregateFunc::MapAgg {
1424                order_by,
1425                value_type,
1426            } => mz_expr::AggregateFunc::MapAgg {
1427                order_by,
1428                value_type,
1429            },
1430            AggregateFunc::ArrayConcat { order_by } => {
1431                mz_expr::AggregateFunc::ArrayConcat { order_by }
1432            }
1433            AggregateFunc::ListConcat { order_by } => {
1434                mz_expr::AggregateFunc::ListConcat { order_by }
1435            }
1436            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1437            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1438            // `AggregateWindowExpr::into_expr`.
1439            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1440                panic!("into_expr called on FusedWindowAgg")
1441            }
1442            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1443        }
1444    }
1445
1446    /// Returns a datum whose inclusion in the aggregation will not change its
1447    /// result.
1448    ///
1449    /// # Panics
1450    ///
1451    /// Panics if called on a `FusedWindowAgg`.
1452    pub fn identity_datum(&self) -> Datum<'static> {
1453        match self {
1454            AggregateFunc::Any => Datum::False,
1455            AggregateFunc::All => Datum::True,
1456            AggregateFunc::Dummy => Datum::Dummy,
1457            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1458            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1459            AggregateFunc::MaxNumeric
1460            | AggregateFunc::MaxInt16
1461            | AggregateFunc::MaxInt32
1462            | AggregateFunc::MaxInt64
1463            | AggregateFunc::MaxUInt16
1464            | AggregateFunc::MaxUInt32
1465            | AggregateFunc::MaxUInt64
1466            | AggregateFunc::MaxMzTimestamp
1467            | AggregateFunc::MaxFloat32
1468            | AggregateFunc::MaxFloat64
1469            | AggregateFunc::MaxBool
1470            | AggregateFunc::MaxString
1471            | AggregateFunc::MaxDate
1472            | AggregateFunc::MaxTimestamp
1473            | AggregateFunc::MaxTimestampTz
1474            | AggregateFunc::MaxInterval
1475            | AggregateFunc::MaxTime
1476            | AggregateFunc::MinNumeric
1477            | AggregateFunc::MinInt16
1478            | AggregateFunc::MinInt32
1479            | AggregateFunc::MinInt64
1480            | AggregateFunc::MinUInt16
1481            | AggregateFunc::MinUInt32
1482            | AggregateFunc::MinUInt64
1483            | AggregateFunc::MinMzTimestamp
1484            | AggregateFunc::MinFloat32
1485            | AggregateFunc::MinFloat64
1486            | AggregateFunc::MinBool
1487            | AggregateFunc::MinString
1488            | AggregateFunc::MinDate
1489            | AggregateFunc::MinTimestamp
1490            | AggregateFunc::MinTimestampTz
1491            | AggregateFunc::MinInterval
1492            | AggregateFunc::MinTime
1493            | AggregateFunc::SumInt16
1494            | AggregateFunc::SumInt32
1495            | AggregateFunc::SumInt64
1496            | AggregateFunc::SumUInt16
1497            | AggregateFunc::SumUInt32
1498            | AggregateFunc::SumUInt64
1499            | AggregateFunc::SumFloat32
1500            | AggregateFunc::SumFloat64
1501            | AggregateFunc::SumNumeric
1502            | AggregateFunc::Count
1503            | AggregateFunc::JsonbAgg { .. }
1504            | AggregateFunc::JsonbObjectAgg { .. }
1505            | AggregateFunc::MapAgg { .. }
1506            | AggregateFunc::StringAgg { .. } => Datum::Null,
1507            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1508                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1509                // in HIR planning, because it is introduced only during HIR transformation.
1510                //
1511                // The implementation could be something like the following, except that we need to
1512                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1513                // ```
1514                // let temp_storage = RowArena::new();
1515                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1516                // ```
1517                panic!("FusedWindowAgg doesn't have an identity_datum")
1518            }
1519        }
1520    }
1521
1522    /// The output column type for the result of an aggregation.
1523    ///
1524    /// The output column type also contains nullability information, which
1525    /// is (without further information) true for aggregations that are not
1526    /// counts.
1527    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1528        let scalar_type = match self {
1529            AggregateFunc::Count => SqlScalarType::Int64,
1530            AggregateFunc::Any => SqlScalarType::Bool,
1531            AggregateFunc::All => SqlScalarType::Bool,
1532            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1533            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1534            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1535            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1536            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1537                max_scale: Some(NumericMaxScale::ZERO),
1538            },
1539            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1540            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1541                max_scale: Some(NumericMaxScale::ZERO),
1542            },
1543            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1544                value_type: Box::new(value_type.clone()),
1545                custom_id: None,
1546            },
1547            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1548                match input_type.scalar_type {
1549                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1550                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1551                    _ => unreachable!(),
1552                }
1553            }
1554            AggregateFunc::MaxNumeric
1555            | AggregateFunc::MaxInt16
1556            | AggregateFunc::MaxInt32
1557            | AggregateFunc::MaxInt64
1558            | AggregateFunc::MaxUInt16
1559            | AggregateFunc::MaxUInt32
1560            | AggregateFunc::MaxUInt64
1561            | AggregateFunc::MaxMzTimestamp
1562            | AggregateFunc::MaxFloat32
1563            | AggregateFunc::MaxFloat64
1564            | AggregateFunc::MaxBool
1565            | AggregateFunc::MaxString
1566            | AggregateFunc::MaxDate
1567            | AggregateFunc::MaxTimestamp
1568            | AggregateFunc::MaxTimestampTz
1569            | AggregateFunc::MaxInterval
1570            | AggregateFunc::MaxTime
1571            | AggregateFunc::MinNumeric
1572            | AggregateFunc::MinInt16
1573            | AggregateFunc::MinInt32
1574            | AggregateFunc::MinInt64
1575            | AggregateFunc::MinUInt16
1576            | AggregateFunc::MinUInt32
1577            | AggregateFunc::MinUInt64
1578            | AggregateFunc::MinMzTimestamp
1579            | AggregateFunc::MinFloat32
1580            | AggregateFunc::MinFloat64
1581            | AggregateFunc::MinBool
1582            | AggregateFunc::MinString
1583            | AggregateFunc::MinDate
1584            | AggregateFunc::MinTimestamp
1585            | AggregateFunc::MinTimestampTz
1586            | AggregateFunc::MinInterval
1587            | AggregateFunc::MinTime
1588            | AggregateFunc::SumFloat32
1589            | AggregateFunc::SumFloat64
1590            | AggregateFunc::SumNumeric
1591            | AggregateFunc::Dummy => input_type.scalar_type,
1592            AggregateFunc::FusedWindowAgg { funcs } => {
1593                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1594                SqlScalarType::Record {
1595                    fields: funcs
1596                        .iter()
1597                        .zip_eq(input_types)
1598                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
1599                        .collect(),
1600                    custom_id: None,
1601                }
1602            }
1603        };
1604        // max/min/sum return null on empty sets
1605        let nullable = !matches!(self, AggregateFunc::Count);
1606        scalar_type.nullable(nullable)
1607    }
1608
1609    pub fn is_order_sensitive(&self) -> bool {
1610        use AggregateFunc::*;
1611        matches!(
1612            self,
1613            JsonbAgg { .. }
1614                | JsonbObjectAgg { .. }
1615                | MapAgg { .. }
1616                | ArrayConcat { .. }
1617                | ListConcat { .. }
1618                | StringAgg { .. }
1619        )
1620    }
1621}
1622
1623impl HirRelationExpr {
1624    /// Gets the SQL type of a self-contained, top-level expression.
1625    pub fn top_level_typ(&self) -> SqlRelationType {
1626        self.typ(&[], &BTreeMap::new())
1627    }
1628
1629    /// Gets the SQL type of the expression.
1630    ///
1631    /// `outers` gives types for outer relations.
1632    /// `params` gives types for parameters.
1633    pub fn typ(
1634        &self,
1635        outers: &[SqlRelationType],
1636        params: &BTreeMap<usize, SqlScalarType>,
1637    ) -> SqlRelationType {
1638        stack::maybe_grow(|| match self {
1639            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1640            HirRelationExpr::Get { typ, .. } => typ.clone(),
1641            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1642            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1643            HirRelationExpr::Project { input, outputs } => {
1644                let input_typ = input.typ(outers, params);
1645                SqlRelationType::new(
1646                    outputs
1647                        .iter()
1648                        .map(|&i| input_typ.column_types[i].clone())
1649                        .collect(),
1650                )
1651            }
1652            HirRelationExpr::Map { input, scalars } => {
1653                let mut typ = input.typ(outers, params);
1654                for scalar in scalars {
1655                    typ.column_types.push(scalar.typ(outers, &typ, params));
1656                }
1657                typ
1658            }
1659            HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(),
1660            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1661                input.typ(outers, params)
1662            }
1663            HirRelationExpr::Join {
1664                left, right, kind, ..
1665            } => {
1666                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1667                let right_nullable =
1668                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1669                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1670                    let nullable = t.nullable || left_nullable;
1671                    t.nullable(nullable)
1672                });
1673                let mut outers = outers.to_vec();
1674                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1675                let rt = right
1676                    .typ(&outers, params)
1677                    .column_types
1678                    .into_iter()
1679                    .map(|t| {
1680                        let nullable = t.nullable || right_nullable;
1681                        t.nullable(nullable)
1682                    });
1683                SqlRelationType::new(lt.chain(rt).collect())
1684            }
1685            HirRelationExpr::Reduce {
1686                input,
1687                group_key,
1688                aggregates,
1689                expected_group_size: _,
1690            } => {
1691                let input_typ = input.typ(outers, params);
1692                let mut column_types = group_key
1693                    .iter()
1694                    .map(|&i| input_typ.column_types[i].clone())
1695                    .collect::<Vec<_>>();
1696                for agg in aggregates {
1697                    column_types.push(agg.typ(outers, &input_typ, params));
1698                }
1699                // TODO(frank): add primary key information.
1700                SqlRelationType::new(column_types)
1701            }
1702            // TODO(frank): check for removal; add primary key information.
1703            HirRelationExpr::Distinct { input }
1704            | HirRelationExpr::Negate { input }
1705            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1706            HirRelationExpr::Union { base, inputs } => {
1707                let mut base_cols = base.typ(outers, params).column_types;
1708                for input in inputs {
1709                    for (base_col, col) in base_cols
1710                        .iter_mut()
1711                        .zip_eq(input.typ(outers, params).column_types)
1712                    {
1713                        *base_col = base_col.sql_union(&col).unwrap(); // HIR deliberately not using `union`
1714                    }
1715                }
1716                SqlRelationType::new(base_cols)
1717            }
1718        })
1719    }
1720
1721    pub fn arity(&self) -> usize {
1722        match self {
1723            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1724            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1725            HirRelationExpr::Let { body, .. } => body.arity(),
1726            HirRelationExpr::LetRec { body, .. } => body.arity(),
1727            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1728            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1729            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1730            HirRelationExpr::Filter { input, .. }
1731            | HirRelationExpr::TopK { input, .. }
1732            | HirRelationExpr::Distinct { input }
1733            | HirRelationExpr::Negate { input }
1734            | HirRelationExpr::Threshold { input } => input.arity(),
1735            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1736            HirRelationExpr::Union { base, .. } => base.arity(),
1737            HirRelationExpr::Reduce {
1738                group_key,
1739                aggregates,
1740                ..
1741            } => group_key.len() + aggregates.len(),
1742        }
1743    }
1744
1745    /// If self is a constant, return the value and the type, otherwise `None`.
1746    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1747        match self {
1748            Self::Constant { rows, typ } => Some((rows, typ)),
1749            _ => None,
1750        }
1751    }
1752
1753    /// Reports whether this expression contains a column reference to its
1754    /// direct parent scope.
1755    pub fn is_correlated(&self) -> bool {
1756        let mut correlated = false;
1757        #[allow(deprecated)]
1758        self.visit_columns(0, &mut |depth, col| {
1759            if col.level > depth && col.level - depth == 1 {
1760                correlated = true;
1761            }
1762        });
1763        correlated
1764    }
1765
1766    pub fn is_join_identity(&self) -> bool {
1767        match self {
1768            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1769            _ => false,
1770        }
1771    }
1772
1773    pub fn project(self, outputs: Vec<usize>) -> Self {
1774        if outputs.iter().copied().eq(0..self.arity()) {
1775            // The projection is trivial. Suppress it.
1776            self
1777        } else {
1778            HirRelationExpr::Project {
1779                input: Box::new(self),
1780                outputs,
1781            }
1782        }
1783    }
1784
1785    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1786        if scalars.is_empty() {
1787            // The map is trivial. Suppress it.
1788            self
1789        } else if let HirRelationExpr::Map {
1790            scalars: old_scalars,
1791            input: _,
1792        } = &mut self
1793        {
1794            // Map applied to a map. Fuse the maps.
1795            old_scalars.extend(scalars);
1796            self
1797        } else {
1798            HirRelationExpr::Map {
1799                input: Box::new(self),
1800                scalars,
1801            }
1802        }
1803    }
1804
1805    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1806        if let HirRelationExpr::Filter {
1807            input: _,
1808            predicates,
1809        } = &mut self
1810        {
1811            predicates.extend(preds);
1812            predicates.sort();
1813            predicates.dedup();
1814            self
1815        } else {
1816            preds.sort();
1817            preds.dedup();
1818            HirRelationExpr::Filter {
1819                input: Box::new(self),
1820                predicates: preds,
1821            }
1822        }
1823    }
1824
1825    pub fn reduce(
1826        self,
1827        group_key: Vec<usize>,
1828        aggregates: Vec<AggregateExpr>,
1829        expected_group_size: Option<u64>,
1830    ) -> Self {
1831        HirRelationExpr::Reduce {
1832            input: Box::new(self),
1833            group_key,
1834            aggregates,
1835            expected_group_size,
1836        }
1837    }
1838
1839    pub fn top_k(
1840        self,
1841        group_key: Vec<usize>,
1842        order_key: Vec<ColumnOrder>,
1843        limit: Option<HirScalarExpr>,
1844        offset: HirScalarExpr,
1845        expected_group_size: Option<u64>,
1846    ) -> Self {
1847        HirRelationExpr::TopK {
1848            input: Box::new(self),
1849            group_key,
1850            order_key,
1851            limit,
1852            offset,
1853            expected_group_size,
1854        }
1855    }
1856
1857    pub fn negate(self) -> Self {
1858        if let HirRelationExpr::Negate { input } = self {
1859            *input
1860        } else {
1861            HirRelationExpr::Negate {
1862                input: Box::new(self),
1863            }
1864        }
1865    }
1866
1867    pub fn distinct(self) -> Self {
1868        if let HirRelationExpr::Distinct { .. } = self {
1869            self
1870        } else {
1871            HirRelationExpr::Distinct {
1872                input: Box::new(self),
1873            }
1874        }
1875    }
1876
1877    pub fn threshold(self) -> Self {
1878        if let HirRelationExpr::Threshold { .. } = self {
1879            self
1880        } else {
1881            HirRelationExpr::Threshold {
1882                input: Box::new(self),
1883            }
1884        }
1885    }
1886
1887    pub fn union(self, other: Self) -> Self {
1888        let mut terms = Vec::new();
1889        if let HirRelationExpr::Union { base, inputs } = self {
1890            terms.push(*base);
1891            terms.extend(inputs);
1892        } else {
1893            terms.push(self);
1894        }
1895        if let HirRelationExpr::Union { base, inputs } = other {
1896            terms.push(*base);
1897            terms.extend(inputs);
1898        } else {
1899            terms.push(other);
1900        }
1901        HirRelationExpr::Union {
1902            base: Box::new(terms.remove(0)),
1903            inputs: terms,
1904        }
1905    }
1906
1907    pub fn exists(self) -> HirScalarExpr {
1908        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1909    }
1910
1911    pub fn select(self) -> HirScalarExpr {
1912        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1913    }
1914
1915    pub fn join(
1916        self,
1917        mut right: HirRelationExpr,
1918        on: HirScalarExpr,
1919        kind: JoinKind,
1920    ) -> HirRelationExpr {
1921        if self.is_join_identity()
1922            && !right.is_correlated()
1923            && on == HirScalarExpr::literal_true()
1924            && kind.can_elide_identity_left_join()
1925        {
1926            // The join can be elided, but we need to adjust column references
1927            // on the right-hand side to account for the removal of the scope
1928            // introduced by the join.
1929            #[allow(deprecated)]
1930            right.visit_columns_mut(0, &mut |depth, col| {
1931                if col.level > depth {
1932                    col.level -= 1;
1933                }
1934            });
1935            right
1936        } else if right.is_join_identity()
1937            && on == HirScalarExpr::literal_true()
1938            && kind.can_elide_identity_right_join()
1939        {
1940            self
1941        } else {
1942            HirRelationExpr::Join {
1943                left: Box::new(self),
1944                right: Box::new(right),
1945                on,
1946                kind,
1947            }
1948        }
1949    }
1950
1951    pub fn take(&mut self) -> HirRelationExpr {
1952        mem::replace(
1953            self,
1954            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1955        )
1956    }
1957
1958    #[deprecated = "Use `Visit::visit_post`."]
1959    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1960    where
1961        F: FnMut(&'a Self, usize),
1962    {
1963        #[allow(deprecated)]
1964        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1965                                                 depth: usize|
1966         -> Result<(), ()> {
1967            f(e, depth);
1968            Ok(())
1969        });
1970    }
1971
1972    #[deprecated = "Use `Visit::try_visit_post`."]
1973    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1974    where
1975        F: FnMut(&'a Self, usize) -> Result<(), E>,
1976    {
1977        #[allow(deprecated)]
1978        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1979            e.visit_fallible(depth, f)
1980        })?;
1981        f(self, depth)
1982    }
1983
1984    /// WARNING: `VisitChildren<HirRelationExpr>::try_visit_children` is NOT a
1985    /// drop-in replacement: in addition to the relation children visited here,
1986    /// it also descends into every subquery (`Exists`/`Select`) reachable
1987    /// through `self`'s scalar children — at any depth of scalar nesting.
1988    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1989    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1990    where
1991        F: FnMut(&'a Self, usize) -> Result<(), E>,
1992    {
1993        match self {
1994            HirRelationExpr::Constant { .. }
1995            | HirRelationExpr::Get { .. }
1996            | HirRelationExpr::CallTable { .. } => (),
1997            HirRelationExpr::Let { body, value, .. } => {
1998                f(value, depth)?;
1999                f(body, depth)?;
2000            }
2001            HirRelationExpr::LetRec {
2002                limit: _,
2003                bindings,
2004                body,
2005            } => {
2006                for (_, _, value, _) in bindings.iter() {
2007                    f(value, depth)?;
2008                }
2009                f(body, depth)?;
2010            }
2011            HirRelationExpr::Project { input, .. } => {
2012                f(input, depth)?;
2013            }
2014            HirRelationExpr::Map { input, .. } => {
2015                f(input, depth)?;
2016            }
2017            HirRelationExpr::Filter { input, .. } => {
2018                f(input, depth)?;
2019            }
2020            HirRelationExpr::Join { left, right, .. } => {
2021                f(left, depth)?;
2022                f(right, depth + 1)?;
2023            }
2024            HirRelationExpr::Reduce { input, .. } => {
2025                f(input, depth)?;
2026            }
2027            HirRelationExpr::Distinct { input } => {
2028                f(input, depth)?;
2029            }
2030            HirRelationExpr::TopK { input, .. } => {
2031                f(input, depth)?;
2032            }
2033            HirRelationExpr::Negate { input } => {
2034                f(input, depth)?;
2035            }
2036            HirRelationExpr::Threshold { input } => {
2037                f(input, depth)?;
2038            }
2039            HirRelationExpr::Union { base, inputs } => {
2040                f(base, depth)?;
2041                for input in inputs {
2042                    f(input, depth)?;
2043                }
2044            }
2045        }
2046        Ok(())
2047    }
2048
2049    #[deprecated = "Use `Visit::visit_mut_post` instead."]
2050    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2051    where
2052        F: FnMut(&mut Self, usize),
2053    {
2054        #[allow(deprecated)]
2055        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2056                                                     depth: usize|
2057         -> Result<(), ()> {
2058            f(e, depth);
2059            Ok(())
2060        });
2061    }
2062
2063    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2064    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2065    where
2066        F: FnMut(&mut Self, usize) -> Result<(), E>,
2067    {
2068        #[allow(deprecated)]
2069        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2070            e.visit_mut_fallible(depth, f)
2071        })?;
2072        f(self, depth)
2073    }
2074
2075    /// WARNING: `VisitChildren<HirRelationExpr>::try_visit_mut_children` is NOT a
2076    /// drop-in replacement: in addition to the relation children visited here,
2077    /// it also descends into every subquery (`Exists`/`Select`) reachable
2078    /// through `self`'s scalar children — at any depth of scalar nesting.
2079    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2080    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2081    where
2082        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2083    {
2084        match self {
2085            HirRelationExpr::Constant { .. }
2086            | HirRelationExpr::Get { .. }
2087            | HirRelationExpr::CallTable { .. } => (),
2088            HirRelationExpr::Let { body, value, .. } => {
2089                f(value, depth)?;
2090                f(body, depth)?;
2091            }
2092            HirRelationExpr::LetRec {
2093                limit: _,
2094                bindings,
2095                body,
2096            } => {
2097                for (_, _, value, _) in bindings.iter_mut() {
2098                    f(value, depth)?;
2099                }
2100                f(body, depth)?;
2101            }
2102            HirRelationExpr::Project { input, .. } => {
2103                f(input, depth)?;
2104            }
2105            HirRelationExpr::Map { input, .. } => {
2106                f(input, depth)?;
2107            }
2108            HirRelationExpr::Filter { input, .. } => {
2109                f(input, depth)?;
2110            }
2111            HirRelationExpr::Join { left, right, .. } => {
2112                f(left, depth)?;
2113                f(right, depth + 1)?;
2114            }
2115            HirRelationExpr::Reduce { input, .. } => {
2116                f(input, depth)?;
2117            }
2118            HirRelationExpr::Distinct { input } => {
2119                f(input, depth)?;
2120            }
2121            HirRelationExpr::TopK { input, .. } => {
2122                f(input, depth)?;
2123            }
2124            HirRelationExpr::Negate { input } => {
2125                f(input, depth)?;
2126            }
2127            HirRelationExpr::Threshold { input } => {
2128                f(input, depth)?;
2129            }
2130            HirRelationExpr::Union { base, inputs } => {
2131                f(base, depth)?;
2132                for input in inputs {
2133                    f(input, depth)?;
2134                }
2135            }
2136        }
2137        Ok(())
2138    }
2139
2140    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2141    /// Visits all scalar expressions directly held by relation nodes within the sub-tree of `self`.
2142    ///
2143    /// Note: this does NOT descend into subqueries that may appear inside the visited
2144    /// `HirScalarExpr`s (i.e., `HirScalarExpr::Exists` / `HirScalarExpr::Select`). The closure
2145    /// `f` is invoked once per top-level scalar expression attached to a relation node, and it
2146    /// is the closure's responsibility to recurse into any subqueries if desired.
2147    ///
2148    /// The `depth` argument is just a seed: it is the value passed to `f` for scalar expressions
2149    /// at the root of `self`, and it is incremented by 1 when descending into the RHS of a
2150    /// `Join` node (the only place this function increments it). It does NOT control or limit
2151    /// recursion; passing `0` is the usual choice.
2152    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2153    where
2154        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2155    {
2156        #[allow(deprecated)]
2157        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2158                                         depth: usize|
2159         -> Result<(), E> {
2160            match e {
2161                HirRelationExpr::Join { on, .. } => {
2162                    f(on, depth)?;
2163                }
2164                HirRelationExpr::Map { scalars, .. } => {
2165                    for scalar in scalars {
2166                        f(scalar, depth)?;
2167                    }
2168                }
2169                HirRelationExpr::CallTable { exprs, .. } => {
2170                    for expr in exprs {
2171                        f(expr, depth)?;
2172                    }
2173                }
2174                HirRelationExpr::Filter { predicates, .. } => {
2175                    for predicate in predicates {
2176                        f(predicate, depth)?;
2177                    }
2178                }
2179                HirRelationExpr::Reduce { aggregates, .. } => {
2180                    for aggregate in aggregates {
2181                        f(&aggregate.expr, depth)?;
2182                    }
2183                }
2184                HirRelationExpr::TopK { limit, offset, .. } => {
2185                    if let Some(limit) = limit {
2186                        f(limit, depth)?;
2187                    }
2188                    f(offset, depth)?;
2189                }
2190                HirRelationExpr::Union { .. }
2191                | HirRelationExpr::Let { .. }
2192                | HirRelationExpr::LetRec { .. }
2193                | HirRelationExpr::Project { .. }
2194                | HirRelationExpr::Distinct { .. }
2195                | HirRelationExpr::Negate { .. }
2196                | HirRelationExpr::Threshold { .. }
2197                | HirRelationExpr::Constant { .. }
2198                | HirRelationExpr::Get { .. } => (),
2199            }
2200            Ok(())
2201        })
2202    }
2203
2204    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2205    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2206    ///
2207    /// In particular, this also does NOT descend into subqueries inside the visited scalar
2208    /// expressions, and `depth` is just a seed for the value passed to `f` (see
2209    /// [`HirRelationExpr::visit_scalar_expressions`]).
2210    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2211    where
2212        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2213    {
2214        #[allow(deprecated)]
2215        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2216                                             depth: usize|
2217         -> Result<(), E> {
2218            match e {
2219                HirRelationExpr::Join { on, .. } => {
2220                    f(on, depth)?;
2221                }
2222                HirRelationExpr::Map { scalars, .. } => {
2223                    for scalar in scalars.iter_mut() {
2224                        f(scalar, depth)?;
2225                    }
2226                }
2227                HirRelationExpr::CallTable { exprs, .. } => {
2228                    for expr in exprs.iter_mut() {
2229                        f(expr, depth)?;
2230                    }
2231                }
2232                HirRelationExpr::Filter { predicates, .. } => {
2233                    for predicate in predicates.iter_mut() {
2234                        f(predicate, depth)?;
2235                    }
2236                }
2237                HirRelationExpr::Reduce { aggregates, .. } => {
2238                    for aggregate in aggregates.iter_mut() {
2239                        f(&mut aggregate.expr, depth)?;
2240                    }
2241                }
2242                HirRelationExpr::TopK { limit, offset, .. } => {
2243                    if let Some(limit) = limit {
2244                        f(limit, depth)?;
2245                    }
2246                    f(offset, depth)?;
2247                }
2248                HirRelationExpr::Union { .. }
2249                | HirRelationExpr::Let { .. }
2250                | HirRelationExpr::LetRec { .. }
2251                | HirRelationExpr::Project { .. }
2252                | HirRelationExpr::Distinct { .. }
2253                | HirRelationExpr::Negate { .. }
2254                | HirRelationExpr::Threshold { .. }
2255                | HirRelationExpr::Constant { .. }
2256                | HirRelationExpr::Get { .. } => (),
2257            }
2258            Ok(())
2259        })
2260    }
2261
2262    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2263    /// Visits the column references in this relation expression.
2264    ///
2265    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2266    /// which will be incremented when entering the RHS of a join or a subquery and
2267    /// presented to the supplied function `f`.
2268    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2269    where
2270        F: FnMut(usize, &ColumnRef),
2271    {
2272        #[allow(deprecated)]
2273        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2274                                                           depth: usize|
2275         -> Result<(), ()> {
2276            e.visit_columns(depth, f);
2277            Ok(())
2278        });
2279    }
2280
2281    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2282    /// Like `visit_columns`, but permits mutating the column references.
2283    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2284    where
2285        F: FnMut(usize, &mut ColumnRef),
2286    {
2287        #[allow(deprecated)]
2288        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2289                                                               depth: usize|
2290         -> Result<(), ()> {
2291            e.visit_columns_mut(depth, f);
2292            Ok(())
2293        });
2294    }
2295
2296    /// Replaces parameter references in the expression with the corresponding datum from `params`.
2297    /// Additionally, it simplifies OFFSET clauses to constants after parameter binding, and checks
2298    /// them for non-negativity.
2299    ///
2300    /// This walks the entire `HirRelationExpr` tree, including subqueries nested inside scalar
2301    /// expressions: `visit_scalar_expressions_mut` covers the scalar expressions held directly
2302    /// by relation nodes, and the corecursive call into
2303    /// [`HirScalarExpr::bind_parameters_and_simplify_offset`] (made by the closure below) is
2304    /// what descends into subqueries occurring inside those scalar expressions.
2305    pub fn bind_parameters_and_simplify_offset(
2306        &mut self,
2307        scx: &StatementContext,
2308        lifetime: QueryLifetime,
2309        params: &Params,
2310    ) -> Result<(), PlanError> {
2311        #[allow(deprecated)]
2312        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2313            e.bind_parameters_and_simplify_offset(scx, lifetime, params)
2314        })?;
2315
2316        // OFFSET clauses in `expr` should become constants with the above binding of parameters.
2317        // Let's check this and simplify them to literals.
2318        self.try_visit_mut_pre(&mut |expr| {
2319            if let HirRelationExpr::TopK { offset, .. } = expr {
2320                let offset_value = offset_into_value(offset.take())?;
2321                *offset = HirScalarExpr::literal(Datum::Int64(offset_value), SqlScalarType::Int64);
2322            }
2323            Ok::<(), PlanError>(())
2324        })
2325        // (We don't need to simplify LIMIT clauses in `expr`, because we can handle non-constant
2326        // expressions there. If they happen to be simplifiable to literals, then the optimizer will do
2327        // so later.)
2328    }
2329
2330    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2331        let mut contains_parameters = false;
2332        #[allow(deprecated)]
2333        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2334            if e.contains_parameters() {
2335                contains_parameters = true;
2336            }
2337            Ok::<(), PlanError>(())
2338        })?;
2339        Ok(contains_parameters)
2340    }
2341
2342    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2343    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2344        #[allow(deprecated)]
2345        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2346                                                               depth: usize|
2347         -> Result<(), ()> {
2348            e.splice_parameters(params, depth);
2349            Ok(())
2350        });
2351    }
2352
2353    /// Constructs a constant collection from specific rows and schema.
2354    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2355        let rows = rows
2356            .into_iter()
2357            .map(move |datums| Row::pack_slice(&datums))
2358            .collect();
2359        HirRelationExpr::Constant { rows, typ }
2360    }
2361
2362    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2363    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2364    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2365    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2366    /// finishing into a trivial finishing.
2367    pub fn finish_maintained(
2368        &mut self,
2369        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2370        group_size_hints: GroupSizeHints,
2371    ) {
2372        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2373            let old_finishing = mem::replace(
2374                finishing,
2375                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2376            );
2377            *self = HirRelationExpr::top_k(
2378                std::mem::replace(
2379                    self,
2380                    HirRelationExpr::Constant {
2381                        rows: vec![],
2382                        typ: SqlRelationType::new(Vec::new()),
2383                    },
2384                ),
2385                vec![],
2386                old_finishing.order_by,
2387                old_finishing.limit,
2388                old_finishing.offset,
2389                group_size_hints.limit_input_group_size,
2390            )
2391            .project(old_finishing.project);
2392        }
2393    }
2394
2395    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2396    ///
2397    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2398    /// parameter is not an HirScalarExpr anymore.)
2399    pub fn trivial_row_set_finishing_hir(
2400        arity: usize,
2401    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2402        RowSetFinishing {
2403            order_by: Vec::new(),
2404            limit: None,
2405            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2406            project: (0..arity).collect(),
2407        }
2408    }
2409
2410    /// True if the finishing does nothing to any result set.
2411    ///
2412    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2413    /// parameter is not an HirScalarExpr anymore.)
2414    pub fn is_trivial_row_set_finishing_hir(
2415        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2416        arity: usize,
2417    ) -> bool {
2418        rsf.limit.is_none()
2419            && rsf.order_by.is_empty()
2420            && rsf
2421                .offset
2422                .clone()
2423                .try_into_literal_int64()
2424                .is_ok_and(|o| o == 0)
2425            && rsf.project.iter().copied().eq(0..arity)
2426    }
2427
2428    /// The HirRelationExpr is considered potentially expensive if and only if
2429    /// at least one of the following conditions is true:
2430    ///
2431    ///  - It contains at least one HirScalarExpr with a function call.
2432    ///  - It contains at least one CallTable or a Reduce operator.
2433    ///  - We run into a RecursionLimitError while analyzing the expression.
2434    ///
2435    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2436    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2437    pub fn could_run_expensive_function(&self) -> bool {
2438        let mut result = false;
2439        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2440            use HirRelationExpr::*;
2441            use HirScalarExpr::*;
2442
2443            e.visit_children(|scalar: &HirScalarExpr| {
2444                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2445                    result |= match scalar {
2446                        Column(..)
2447                        | Literal(..)
2448                        | CallUnmaterializable(..)
2449                        | If { .. }
2450                        | Parameter(..)
2451                        | Select(..)
2452                        | Exists(..) => false,
2453                        // Function calls are considered expensive
2454                        CallUnary { .. }
2455                        | CallBinary { .. }
2456                        | CallVariadic { .. }
2457                        | Windowing(..) => true,
2458                    };
2459                }) {
2460                    // Conservatively set `true` on RecursionLimitError.
2461                    result = true;
2462                }
2463            });
2464
2465            // CallTable has a table function; Reduce has an aggregate function.
2466            // Other constructs use MirScalarExpr to run a function
2467            result |= matches!(e, CallTable { .. } | Reduce { .. });
2468        }) {
2469            // Conservatively set `true` on RecursionLimitError.
2470            result = true;
2471        }
2472
2473        result
2474    }
2475
2476    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2477    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2478        let mut contains = false;
2479        self.visit_post(&mut |expr| {
2480            expr.visit_children(|expr: &HirScalarExpr| {
2481                contains = contains || expr.contains_temporal()
2482            })
2483        })?;
2484        Ok(contains)
2485    }
2486
2487    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2488    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2489        let mut contains = false;
2490        self.visit_post(&mut |expr| {
2491            expr.visit_children(|expr: &HirScalarExpr| {
2492                contains = contains || expr.contains_unmaterializable()
2493            })
2494        })?;
2495        Ok(contains)
2496    }
2497
2498    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
2499    /// [`UnmaterializableFunc::MzNow`].
2500    pub fn contains_unmaterializable_except_temporal(&self) -> Result<bool, RecursionLimitError> {
2501        let mut contains = false;
2502        self.visit_post(&mut |expr| {
2503            expr.visit_children(|expr: &HirScalarExpr| {
2504                contains = contains || expr.contains_unmaterializable_except_temporal()
2505            })
2506        })?;
2507        Ok(contains)
2508    }
2509}
2510
2511impl CollectionPlan for HirRelationExpr {
2512    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2513    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2514    /// (It does explore inside subqueries.)
2515    ///
2516    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2517    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2518    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2519        if let Self::Get {
2520            id: Id::Global(id), ..
2521        } = self
2522        {
2523            out.insert(*id);
2524        }
2525        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2526    }
2527}
2528
/// In addition to direct relation children of `self`, this also yields every
/// subquery (`Exists` / `Select`) reachable through `self`'s scalar children,
/// at any depth of scalar nesting. This is the asymmetry warned about on the
/// [`VisitChildren`] trait; the matching impl on `HirScalarExpr` deliberately
/// does not do the symmetric thing, to avoid mutual recursion.
///
/// All four methods share one visit order. First come the subqueries found through
/// the scalars attached to `self`, in the order the `VisitChildren<HirScalarExpr>`
/// impl yields those scalars. Then come the direct relation children, in the order
/// the `match` arms list them (e.g. `value` before `body`, `left` before `right`,
/// `base` before `inputs`).
///
/// The fallible methods return the first error they encounter. That error may come
/// from `f` or from the recursion-limited scalar walk in
/// `try_visit_direct_subqueries{,_mut}`.
impl VisitChildren<Self> for HirRelationExpr {
    /// Calls `f` on every subquery reachable through `self`'s scalars, then on
    /// every direct relation child of `self`.
    fn visit_children<F>(&self, mut f: F)
    where
        F: FnMut(&Self),
    {
        // subqueries of type HirRelationExpr might be wrapped in
        // Exists or Select variants within HirScalarExpr trees
        // attached at the current node, and we want to visit them as well
        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
            expr.visit_direct_subqueries(&mut f);
        });

        // NOTE(review): fields are listed explicitly (as `_`) rather than with `..`,
        // presumably so that adding a field to a variant forces this visitor to be
        // revisited. The same holds for the three methods below.
        use HirRelationExpr::*;
        match self {
            // Leaves: no relation children.
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value);
                f(body);
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                // Each binding's value, in binding order, then the body.
                for (_, _, value, _) in bindings.iter() {
                    f(value);
                }
                f(body);
            }
            Project { input, outputs: _ } => f(input),
            Map { input, scalars: _ } => {
                f(input);
            }
            // No relation children; any subqueries inside `exprs` were already
            // visited by the scalar pass above.
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input);
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left);
                f(right);
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input);
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input);
            }
            Union { base, inputs } => {
                f(base);
                for input in inputs {
                    f(input);
                }
            }
        }
    }

    /// Mutable counterpart of `visit_children`, with the same visit order.
    fn visit_mut_children<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut Self),
    {
        // subqueries of type HirRelationExpr might be wrapped in
        // Exists or Select variants within HirScalarExpr trees
        // attached at the current node, and we want to visit them as well
        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
            expr.visit_direct_subqueries_mut(&mut f);
        });

        use HirRelationExpr::*;
        match self {
            // Leaves: no relation children.
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value);
                f(body);
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                // Each binding's value, in binding order, then the body.
                for (_, _, value, _) in bindings.iter_mut() {
                    f(value);
                }
                f(body);
            }
            Project { input, outputs: _ } => f(input),
            Map { input, scalars: _ } => {
                f(input);
            }
            // No relation children; any subqueries inside `exprs` were already
            // visited by the scalar pass above.
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input);
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left);
                f(right);
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input);
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input);
            }
            Union { base, inputs } => {
                f(base);
                for input in inputs {
                    f(input);
                }
            }
        }
    }

    /// Fallible counterpart of `visit_children`, with the same visit order.
    /// Returns the first error from `f` or from the recursion-limited scalar walk.
    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&Self) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // subqueries of type HirRelationExpr might be wrapped in
        // Exists or Select variants within HirScalarExpr trees
        // attached at the current node, and we want to visit them as well
        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
            expr.try_visit_direct_subqueries(&mut f)
        })?;

        use HirRelationExpr::*;
        match self {
            // Leaves: no relation children.
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value)?;
                f(body)?;
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                // Each binding's value, in binding order, then the body.
                for (_, _, value, _) in bindings.iter() {
                    f(value)?;
                }
                f(body)?;
            }
            Project { input, outputs: _ } => f(input)?,
            Map { input, scalars: _ } => {
                f(input)?;
            }
            // No relation children; any subqueries inside `exprs` were already
            // visited by the scalar pass above.
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input)?;
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left)?;
                f(right)?;
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input)?;
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input)?;
            }
            Union { base, inputs } => {
                f(base)?;
                for input in inputs {
                    f(input)?;
                }
            }
        }
        Ok(())
    }

    /// Fallible mutable counterpart of `visit_children`, with the same visit order.
    /// Returns the first error from `f` or from the recursion-limited scalar walk.
    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&mut Self) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // subqueries of type HirRelationExpr might be wrapped in
        // Exists or Select variants within HirScalarExpr trees
        // attached at the current node, and we want to visit them as well
        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
            expr.try_visit_direct_subqueries_mut(&mut f)
        })?;

        use HirRelationExpr::*;
        match self {
            // Leaves: no relation children.
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value)?;
                f(body)?;
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                // Each binding's value, in binding order, then the body.
                for (_, _, value, _) in bindings.iter_mut() {
                    f(value)?;
                }
                f(body)?;
            }
            Project { input, outputs: _ } => f(input)?,
            Map { input, scalars: _ } => {
                f(input)?;
            }
            // No relation children; any subqueries inside `exprs` were already
            // visited by the scalar pass above.
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input)?;
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left)?;
                f(right)?;
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input)?;
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input)?;
            }
            Union { base, inputs } => {
                f(base)?;
                for input in inputs {
                    f(input)?;
                }
            }
        }
        Ok(())
    }
}
2871
/// Yields the scalars directly attached to relation nodes (e.g. `Map.scalars`,
/// `Filter.predicates`, `Join.on`, `Reduce` aggregate args, `TopK.{limit,
/// offset}`, `CallTable.exprs`); does not descend into them.
///
/// Every other field is ignored by this impl. That includes relation inputs,
/// `Project.outputs`, `group_key`, `order_key`, `expected_group_size`, `Let` /
/// `LetRec` values and bodies, and `Union` inputs. Relation children are yielded
/// by the `VisitChildren<Self>` impl instead.
///
/// Scalars are yielded in field order. For `Reduce`, this means each aggregate's
/// argument expression in aggregate order. For `TopK`, it means `limit` (when
/// present) and then `offset`. The fallible methods stop at the first error
/// returned by `f`.
impl VisitChildren<HirScalarExpr> for HirRelationExpr {
    /// Calls `f` on each scalar directly attached to `self`.
    fn visit_children<F>(&self, mut f: F)
    where
        F: FnMut(&HirScalarExpr),
    {
        use HirRelationExpr::*;
        match self {
            // Variants that hold no scalars of their own.
            Constant { rows: _, typ: _ }
            | Get { id: _, typ: _ }
            | Let {
                name: _,
                id: _,
                value: _,
                body: _,
            }
            | LetRec {
                limit: _,
                bindings: _,
                body: _,
            }
            | Project {
                input: _,
                outputs: _,
            } => (),
            Map { input: _, scalars } => {
                for scalar in scalars {
                    f(scalar);
                }
            }
            CallTable { func: _, exprs } => {
                for expr in exprs {
                    f(expr);
                }
            }
            Filter {
                input: _,
                predicates,
            } => {
                for predicate in predicates {
                    f(predicate);
                }
            }
            Join {
                left: _,
                right: _,
                on,
                kind: _,
            } => f(on),
            Reduce {
                input: _,
                group_key: _,
                aggregates,
                expected_group_size: _,
            } => {
                // The argument expression of each aggregate (stored boxed).
                for aggregate in aggregates {
                    f(aggregate.expr.as_ref());
                }
            }
            TopK {
                input: _,
                group_key: _,
                order_key: _,
                limit,
                offset,
                expected_group_size: _,
            } => {
                // `limit` is optional; `offset` is always present.
                if let Some(limit) = limit {
                    f(limit)
                }
                f(offset)
            }
            Distinct { input: _ }
            | Negate { input: _ }
            | Threshold { input: _ }
            | Union { base: _, inputs: _ } => (),
        }
    }

    /// Mutable counterpart of `visit_children`; yields the same scalars in the same order.
    fn visit_mut_children<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut HirScalarExpr),
    {
        use HirRelationExpr::*;
        match self {
            // Variants that hold no scalars of their own.
            Constant { rows: _, typ: _ }
            | Get { id: _, typ: _ }
            | Let {
                name: _,
                id: _,
                value: _,
                body: _,
            }
            | LetRec {
                limit: _,
                bindings: _,
                body: _,
            }
            | Project {
                input: _,
                outputs: _,
            } => (),
            Map { input: _, scalars } => {
                for scalar in scalars {
                    f(scalar);
                }
            }
            CallTable { func: _, exprs } => {
                for expr in exprs {
                    f(expr);
                }
            }
            Filter {
                input: _,
                predicates,
            } => {
                for predicate in predicates {
                    f(predicate);
                }
            }
            Join {
                left: _,
                right: _,
                on,
                kind: _,
            } => f(on),
            Reduce {
                input: _,
                group_key: _,
                aggregates,
                expected_group_size: _,
            } => {
                // The argument expression of each aggregate (stored boxed).
                for aggregate in aggregates {
                    f(aggregate.expr.as_mut());
                }
            }
            TopK {
                input: _,
                group_key: _,
                order_key: _,
                limit,
                offset,
                expected_group_size: _,
            } => {
                // `limit` is optional; `offset` is always present.
                if let Some(limit) = limit {
                    f(limit)
                }
                f(offset)
            }
            Distinct { input: _ }
            | Negate { input: _ }
            | Threshold { input: _ }
            | Union { base: _, inputs: _ } => (),
        }
    }

    /// Fallible counterpart of `visit_children`; stops at the first error from `f`.
    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&HirScalarExpr) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        use HirRelationExpr::*;
        match self {
            // Variants that hold no scalars of their own.
            Constant { rows: _, typ: _ }
            | Get { id: _, typ: _ }
            | Let {
                name: _,
                id: _,
                value: _,
                body: _,
            }
            | LetRec {
                limit: _,
                bindings: _,
                body: _,
            }
            | Project {
                input: _,
                outputs: _,
            } => (),
            Map { input: _, scalars } => {
                for scalar in scalars {
                    f(scalar)?;
                }
            }
            CallTable { func: _, exprs } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Filter {
                input: _,
                predicates,
            } => {
                for predicate in predicates {
                    f(predicate)?;
                }
            }
            Join {
                left: _,
                right: _,
                on,
                kind: _,
            } => f(on)?,
            Reduce {
                input: _,
                group_key: _,
                aggregates,
                expected_group_size: _,
            } => {
                // The argument expression of each aggregate (stored boxed).
                for aggregate in aggregates {
                    f(aggregate.expr.as_ref())?;
                }
            }
            TopK {
                input: _,
                group_key: _,
                order_key: _,
                limit,
                offset,
                expected_group_size: _,
            } => {
                // `limit` is optional; `offset` is always present.
                if let Some(limit) = limit {
                    f(limit)?
                }
                f(offset)?
            }
            Distinct { input: _ }
            | Negate { input: _ }
            | Threshold { input: _ }
            | Union { base: _, inputs: _ } => (),
        }
        Ok(())
    }

    /// Fallible mutable counterpart of `visit_children`; stops at the first error from `f`.
    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        use HirRelationExpr::*;
        match self {
            // Variants that hold no scalars of their own.
            Constant { rows: _, typ: _ }
            | Get { id: _, typ: _ }
            | Let {
                name: _,
                id: _,
                value: _,
                body: _,
            }
            | LetRec {
                limit: _,
                bindings: _,
                body: _,
            }
            | Project {
                input: _,
                outputs: _,
            } => (),
            Map { input: _, scalars } => {
                for scalar in scalars {
                    f(scalar)?;
                }
            }
            CallTable { func: _, exprs } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Filter {
                input: _,
                predicates,
            } => {
                for predicate in predicates {
                    f(predicate)?;
                }
            }
            Join {
                left: _,
                right: _,
                on,
                kind: _,
            } => f(on)?,
            Reduce {
                input: _,
                group_key: _,
                aggregates,
                expected_group_size: _,
            } => {
                // The argument expression of each aggregate (stored boxed).
                for aggregate in aggregates {
                    f(aggregate.expr.as_mut())?;
                }
            }
            TopK {
                input: _,
                group_key: _,
                order_key: _,
                limit,
                offset,
                expected_group_size: _,
            } => {
                // `limit` is optional; `offset` is always present.
                if let Some(limit) = limit {
                    f(limit)?
                }
                f(offset)?
            }
            Distinct { input: _ }
            | Negate { input: _ }
            | Threshold { input: _ }
            | Union { base: _, inputs: _ } => (),
        }
        Ok(())
    }
}
3188
3189impl HirScalarExpr {
3190    pub fn name(&self) -> Option<Arc<str>> {
3191        use HirScalarExpr::*;
3192        match self {
3193            Column(_, name)
3194            | Parameter(_, name)
3195            | Literal(_, _, name)
3196            | CallUnmaterializable(_, name)
3197            | CallUnary { name, .. }
3198            | CallBinary { name, .. }
3199            | CallVariadic { name, .. }
3200            | If { name, .. }
3201            | Exists(_, name)
3202            | Select(_, name)
3203            | Windowing(_, name) => name.0.clone(),
3204        }
3205    }
3206
3207    /// Visit every subquery, without descending into subqueries.
3208    pub fn visit_direct_subqueries<F>(&self, mut f: F)
3209    where
3210        F: FnMut(&HirRelationExpr),
3211    {
3212        // The infallible variants of the post-walk are deprecated in favor
3213        // of the limit-aware `try_visit_post`, but we have no error channel
3214        // here (the surrounding signature is infallible, mirroring the
3215        // infallible `VisitChildren::{visit_children, visit_mut_children}`
3216        // impls that this helper is designed to be called from).
3217        #[allow(deprecated)]
3218        self.visit_post_nolimit(&mut |e| {
3219            VisitChildren::<HirRelationExpr>::visit_children(e, &mut f);
3220        });
3221    }
3222
3223    /// Mutable counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3224    pub fn visit_direct_subqueries_mut<F>(&mut self, mut f: F)
3225    where
3226        F: FnMut(&mut HirRelationExpr),
3227    {
3228        // See the comment on `visit_direct_subqueries` for why we use the
3229        // deprecated `_nolimit` walk here.
3230        #[allow(deprecated)]
3231        self.visit_mut_post_nolimit(&mut |e| {
3232            VisitChildren::<HirRelationExpr>::visit_mut_children(e, &mut f);
3233        });
3234    }
3235
3236    /// Fallible counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3237    pub fn try_visit_direct_subqueries<F, E>(&self, mut f: F) -> Result<(), E>
3238    where
3239        F: FnMut(&HirRelationExpr) -> Result<(), E>,
3240        E: From<RecursionLimitError>,
3241    {
3242        self.try_visit_post(&mut |e| {
3243            VisitChildren::<HirRelationExpr>::try_visit_children(e, &mut f)
3244        })
3245    }
3246
3247    /// Fallible mutable counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3248    pub fn try_visit_direct_subqueries_mut<F, E>(&mut self, mut f: F) -> Result<(), E>
3249    where
3250        F: FnMut(&mut HirRelationExpr) -> Result<(), E>,
3251        E: From<RecursionLimitError>,
3252    {
3253        self.try_visit_mut_post(&mut |e| {
3254            VisitChildren::<HirRelationExpr>::try_visit_mut_children(e, &mut f)
3255        })
3256    }
3257
3258    /// Replaces parameter references in the expression with the corresponding datum from `params`.
3259    /// Additionally, it simplifies OFFSET clauses to constants after parameter binding, and checks
3260    /// them for non-negativity.
3261    ///
3262    /// This handles subqueries nested inside `self` by calling back into
3263    /// [`HirRelationExpr::bind_parameters_and_simplify_offset`] on each direct subquery; that
3264    /// relation-level function in turn calls back here for the scalar expressions it holds, so
3265    /// the two functions corecursively cover the entire HIR tree.
3266    pub fn bind_parameters_and_simplify_offset(
3267        &mut self,
3268        scx: &StatementContext,
3269        lifetime: QueryLifetime,
3270        params: &Params,
3271    ) -> Result<(), PlanError> {
3272        // First, rewrite each `Parameter` node to a literal. This walks only the scalar tree;
3273        // it does not yet descend into subqueries.
3274        self.try_visit_mut_post(&mut |e: &mut HirScalarExpr| {
3275            if let HirScalarExpr::Parameter(n, name) = e {
3276                let datum = match params.datums.iter().nth(*n - 1) {
3277                    None => return Err(PlanError::UnknownParameter(*n)),
3278                    Some(datum) => datum,
3279                };
3280                let scalar_type = &params.execute_types[*n - 1];
3281                let row = Row::pack([datum]);
3282                let column_type = scalar_type.clone().nullable(datum.is_null());
3283
3284                let name = if let Some(name) = &name.0 {
3285                    Some(Arc::clone(name))
3286                } else {
3287                    Some(Arc::from(format!("${n}")))
3288                };
3289
3290                let qcx = QueryContext::root(scx, lifetime);
3291                let ecx = execute_expr_context(&qcx);
3292
3293                *e = plan_cast(
3294                    &ecx,
3295                    *EXECUTE_CAST_CONTEXT,
3296                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3297                    &params.expected_types[*n - 1],
3298                )
3299                .expect("checked in plan_params");
3300            }
3301            Ok(())
3302        })?;
3303        // Then descend into any subqueries; the relation-side `bind_parameters_and_simplify_offset`
3304        // handles corecursion back into scalars.
3305        self.try_visit_direct_subqueries_mut(|r: &mut HirRelationExpr| {
3306            r.bind_parameters_and_simplify_offset(scx, lifetime, params)
3307        })
3308    }
3309
3310    /// Like [`HirScalarExpr::bind_parameters_and_simplify_offset`], except that parameters are
3311    /// replaced with the corresponding expression fragment from `params` rather
3312    /// than a datum.
3313    ///
3314    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3315    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3316    /// in `self` that refer to invalid indices of `params` will cause a panic.
3317    ///
3318    /// Column references in parameters will be corrected to account for the
3319    /// depth at which they are spliced.
3320    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3321        #[allow(deprecated)]
3322        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3323                                                        e: &mut HirScalarExpr|
3324         -> Result<(), ()> {
3325            if let HirScalarExpr::Parameter(i, _name) = e {
3326                *e = params[*i - 1].clone();
3327                // Correct any column references in the parameter expression for
3328                // its new depth.
3329                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3330                    if col.level >= d {
3331                        col.level += depth
3332                    }
3333                });
3334            }
3335            Ok(())
3336        });
3337    }
3338
3339    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3340    pub fn contains_temporal(&self) -> bool {
3341        let mut contains = false;
3342        #[allow(deprecated)]
3343        self.visit_post_nolimit(&mut |e| {
3344            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3345                contains = true;
3346            }
3347        });
3348        contains
3349    }
3350
3351    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3352    pub fn contains_unmaterializable(&self) -> bool {
3353        let mut contains = false;
3354        #[allow(deprecated)]
3355        self.visit_post_nolimit(&mut |e| {
3356            if let Self::CallUnmaterializable(_, _) = e {
3357                contains = true;
3358            }
3359        });
3360        contains
3361    }
3362
3363    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
3364    /// [`UnmaterializableFunc::MzNow`].
3365    pub fn contains_unmaterializable_except_temporal(&self) -> bool {
3366        let mut contains = false;
3367        #[allow(deprecated)]
3368        self.visit_post_nolimit(&mut |e| {
3369            if let Self::CallUnmaterializable(f, _) = e {
3370                if *f != UnmaterializableFunc::MzNow {
3371                    contains = true;
3372                }
3373            }
3374        });
3375        contains
3376    }
3377
3378    /// Constructs an unnamed column reference in the current scope.
3379    /// Use [`HirScalarExpr::named_column`] when a name is known.
3380    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3381    pub fn column(index: usize) -> HirScalarExpr {
3382        HirScalarExpr::Column(
3383            ColumnRef {
3384                level: 0,
3385                column: index,
3386            },
3387            TreatAsEqual(None),
3388        )
3389    }
3390
3391    /// Constructs an unnamed column reference.
3392    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3393        HirScalarExpr::Column(cr, TreatAsEqual(None))
3394    }
3395
3396    /// Constructs a named column reference.
3397    /// Names are interned by a `NameManager`.
3398    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3399        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3400    }
3401
3402    pub fn parameter(n: usize) -> HirScalarExpr {
3403        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3404    }
3405
3406    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3407        let col_type = scalar_type.nullable(datum.is_null());
3408        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3409        let row = Row::pack([datum]);
3410        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3411    }
3412
3413    pub fn literal_true() -> HirScalarExpr {
3414        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3415    }
3416
3417    pub fn literal_false() -> HirScalarExpr {
3418        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3419    }
3420
3421    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3422        HirScalarExpr::literal(Datum::Null, scalar_type)
3423    }
3424
3425    pub fn literal_1d_array(
3426        datums: Vec<Datum>,
3427        element_scalar_type: SqlScalarType,
3428    ) -> Result<HirScalarExpr, PlanError> {
3429        let scalar_type = match element_scalar_type {
3430            SqlScalarType::Array(_) => {
3431                sql_bail!("cannot build array from array type");
3432            }
3433            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3434        };
3435
3436        let mut row = Row::default();
3437        row.packer()
3438            .try_push_array(
3439                &[ArrayDimension {
3440                    lower_bound: 1,
3441                    length: datums.len(),
3442                }],
3443                datums,
3444            )
3445            .expect("array constructed to be valid");
3446
3447        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3448    }
3449
3450    pub fn as_literal(&self) -> Option<Datum<'_>> {
3451        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3452            Some(row.unpack_first())
3453        } else {
3454            None
3455        }
3456    }
3457
3458    pub fn is_literal_true(&self) -> bool {
3459        Some(Datum::True) == self.as_literal()
3460    }
3461
3462    pub fn is_literal_false(&self) -> bool {
3463        Some(Datum::False) == self.as_literal()
3464    }
3465
3466    pub fn is_literal_null(&self) -> bool {
3467        Some(Datum::Null) == self.as_literal()
3468    }
3469
3470    /// Return true iff `self` consists only of literals, materializable function calls, and
3471    /// if-else statements.
3472    pub fn is_constant(&self) -> bool {
3473        let mut worklist = vec![self];
3474        while let Some(expr) = worklist.pop() {
3475            match expr {
3476                Self::Literal(..) => {
3477                    // leaf node, do nothing
3478                }
3479                Self::CallUnary { expr, .. } => {
3480                    worklist.push(expr);
3481                }
3482                Self::CallBinary {
3483                    func: _,
3484                    expr1,
3485                    expr2,
3486                    name: _,
3487                } => {
3488                    worklist.push(expr1);
3489                    worklist.push(expr2);
3490                }
3491                Self::CallVariadic {
3492                    func: _,
3493                    exprs,
3494                    name: _,
3495                } => {
3496                    worklist.extend(exprs.iter());
3497                }
3498                // (CallUnmaterializable is not allowed)
3499                Self::If {
3500                    cond,
3501                    then,
3502                    els,
3503                    name: _,
3504                } => {
3505                    worklist.push(cond);
3506                    worklist.push(then);
3507                    worklist.push(els);
3508                }
3509                _ => {
3510                    return false; // Any other node makes `self` non-constant.
3511                }
3512            }
3513        }
3514        true
3515    }
3516
3517    pub fn call_unary(self, func: UnaryFunc) -> Self {
3518        HirScalarExpr::CallUnary {
3519            func,
3520            expr: Box::new(self),
3521            name: NameMetadata::default(),
3522        }
3523    }
3524
3525    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3526        HirScalarExpr::CallBinary {
3527            func: func.into(),
3528            expr1: Box::new(self),
3529            expr2: Box::new(other),
3530            name: NameMetadata::default(),
3531        }
3532    }
3533
3534    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3535        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3536    }
3537
3538    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
3539        HirScalarExpr::CallVariadic {
3540            func: func.into(),
3541            exprs,
3542            name: NameMetadata::default(),
3543        }
3544    }
3545
3546    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3547        HirScalarExpr::If {
3548            cond: Box::new(cond),
3549            then: Box::new(then),
3550            els: Box::new(els),
3551            name: NameMetadata::default(),
3552        }
3553    }
3554
3555    pub fn windowing(expr: WindowExpr) -> Self {
3556        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3557    }
3558
3559    pub fn or(self, other: Self) -> Self {
3560        HirScalarExpr::call_variadic(Or, vec![self, other])
3561    }
3562
3563    pub fn and(self, other: Self) -> Self {
3564        HirScalarExpr::call_variadic(And, vec![self, other])
3565    }
3566
3567    pub fn not(self) -> Self {
3568        self.call_unary(UnaryFunc::Not(func::Not))
3569    }
3570
3571    pub fn call_is_null(self) -> Self {
3572        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3573    }
3574
3575    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3576    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3577        match args.len() {
3578            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3579            1 => args.swap_remove(0),
3580            _ => HirScalarExpr::call_variadic(And, args),
3581        }
3582    }
3583
3584    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3585    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3586        match args.len() {
3587            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3588            1 => args.swap_remove(0),
3589            _ => HirScalarExpr::call_variadic(Or, args),
3590        }
3591    }
3592
3593    pub fn take(&mut self) -> Self {
3594        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3595    }
3596
3597    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3598    /// Visits the column references in this scalar expression.
3599    ///
3600    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3601    /// which will be incremented with each subquery entered and presented to the supplied
3602    /// function `f`.
3603    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3604    where
3605        F: FnMut(usize, &ColumnRef),
3606    {
3607        #[allow(deprecated)]
3608        let _ = self.visit_recursively(depth, &mut |depth: usize,
3609                                                    e: &HirScalarExpr|
3610         -> Result<(), ()> {
3611            if let HirScalarExpr::Column(col, _name) = e {
3612                f(depth, col)
3613            }
3614            Ok(())
3615        });
3616    }
3617
3618    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3619    /// Like `visit_columns`, but permits mutating the column references.
3620    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3621    where
3622        F: FnMut(usize, &mut ColumnRef),
3623    {
3624        #[allow(deprecated)]
3625        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3626                                                        e: &mut HirScalarExpr|
3627         -> Result<(), ()> {
3628            if let HirScalarExpr::Column(col, _name) = e {
3629                f(depth, col)
3630            }
3631            Ok(())
3632        });
3633    }
3634
3635    /// Visits those column references in this scalar expression that refer to the root
3636    /// level. These include column references that are at the root level, as well as column
3637    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3638    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3639    /// "root level" to be `self`'s level.)
3640    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3641    where
3642        F: FnMut(usize),
3643    {
3644        #[allow(deprecated)]
3645        let _ = self.visit_recursively(0, &mut |depth: usize,
3646                                                e: &HirScalarExpr|
3647         -> Result<(), ()> {
3648            if let HirScalarExpr::Column(col, _name) = e {
3649                if col.level == depth {
3650                    f(col.column)
3651                }
3652            }
3653            Ok(())
3654        });
3655    }
3656
3657    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3658    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3659    where
3660        F: FnMut(&mut usize),
3661    {
3662        #[allow(deprecated)]
3663        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3664                                                    e: &mut HirScalarExpr|
3665         -> Result<(), ()> {
3666            if let HirScalarExpr::Column(col, _name) = e {
3667                if col.level == depth {
3668                    f(&mut col.column)
3669                }
3670            }
3671            Ok(())
3672        });
3673    }
3674
3675    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3676    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3677    /// in them. It takes the current depth of the expression and increases it when
3678    /// entering a subquery.
3679    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3680    where
3681        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3682    {
3683        match self {
3684            HirScalarExpr::Literal(..)
3685            | HirScalarExpr::Parameter(..)
3686            | HirScalarExpr::CallUnmaterializable(..)
3687            | HirScalarExpr::Column(..) => (),
3688            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3689            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3690                expr1.visit_recursively(depth, f)?;
3691                expr2.visit_recursively(depth, f)?;
3692            }
3693            HirScalarExpr::CallVariadic { exprs, .. } => {
3694                for expr in exprs {
3695                    expr.visit_recursively(depth, f)?;
3696                }
3697            }
3698            HirScalarExpr::If {
3699                cond,
3700                then,
3701                els,
3702                name: _,
3703            } => {
3704                cond.visit_recursively(depth, f)?;
3705                then.visit_recursively(depth, f)?;
3706                els.visit_recursively(depth, f)?;
3707            }
3708            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3709                #[allow(deprecated)]
3710                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3711                    e.visit_recursively(depth, f)
3712                })?;
3713            }
3714            HirScalarExpr::Windowing(expr, _name) => {
3715                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3716            }
3717        }
3718        f(depth, self)
3719    }
3720
3721    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3722    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3723    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3724    where
3725        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3726    {
3727        match self {
3728            HirScalarExpr::Literal(..)
3729            | HirScalarExpr::Parameter(..)
3730            | HirScalarExpr::CallUnmaterializable(..)
3731            | HirScalarExpr::Column(..) => (),
3732            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3733            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3734                expr1.visit_recursively_mut(depth, f)?;
3735                expr2.visit_recursively_mut(depth, f)?;
3736            }
3737            HirScalarExpr::CallVariadic { exprs, .. } => {
3738                for expr in exprs {
3739                    expr.visit_recursively_mut(depth, f)?;
3740                }
3741            }
3742            HirScalarExpr::If {
3743                cond,
3744                then,
3745                els,
3746                name: _,
3747            } => {
3748                cond.visit_recursively_mut(depth, f)?;
3749                then.visit_recursively_mut(depth, f)?;
3750                els.visit_recursively_mut(depth, f)?;
3751            }
3752            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3753                #[allow(deprecated)]
3754                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3755                    e.visit_recursively_mut(depth, f)
3756                })?;
3757            }
3758            HirScalarExpr::Windowing(expr, _name) => {
3759                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3760            }
3761        }
3762        f(depth, self)
3763    }
3764
3765    /// Attempts to simplify self into a literal.
3766    ///
3767    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3768    /// an evaluation error occurs during simplification, or if self contains
3769    /// - a subquery
3770    /// - a column reference to an outer level
3771    /// - a parameter
3772    /// - a window function call
3773    fn simplify_to_literal(self) -> Option<Row> {
3774        let mut expr = self
3775            .lower_uncorrelated(crate::plan::lowering::Config::default())
3776            .ok()?;
3777        // Using MIR evaluation with repr types is fine here: the
3778        // result is an untyped Row, so any intermediate type
3779        // canonicalization is discarded.
3780        expr.reduce(&[]);
3781        match expr {
3782            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3783            _ => None,
3784        }
3785    }
3786
3787    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3788    /// or an evaluation error occurs during simplification), it returns
3789    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3790    ///
3791    /// The returned error is an _internal_ error if the expression contains
3792    /// - a subquery
3793    /// - a column reference to an outer level
3794    /// - a parameter
3795    /// - a window function call
3796    ///
3797    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3798    /// msg.
3799    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3800        let mut expr = self
3801            .lower_uncorrelated(crate::plan::lowering::Config::default())
3802            .map_err(|err| {
3803                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3804            })?;
3805        // Using MIR evaluation with repr types is fine here: the
3806        // result is an untyped Row, so any intermediate type
3807        // canonicalization is discarded.
3808        expr.reduce(&[]);
3809        match expr {
3810            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3811            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3812                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3813            ),
3814            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3815                "Not a constant".to_string(),
3816            )),
3817        }
3818    }
3819
3820    /// Attempts to simplify this expression to a literal 64-bit integer.
3821    ///
3822    /// Returns `None` if this expression cannot be simplified, e.g. because it
3823    /// contains non-literal values.
3824    ///
3825    /// # Panics
3826    ///
3827    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3828    pub fn into_literal_int64(self) -> Option<i64> {
3829        self.simplify_to_literal().and_then(|row| {
3830            let datum = row.unpack_first();
3831            if datum.is_null() {
3832                None
3833            } else {
3834                Some(datum.unwrap_int64())
3835            }
3836        })
3837    }
3838
3839    /// Attempts to simplify this expression to a literal string.
3840    ///
3841    /// Returns `None` if this expression cannot be simplified, e.g. because it
3842    /// contains non-literal values.
3843    ///
3844    /// # Panics
3845    ///
3846    /// Panics if this expression does not have type [`SqlScalarType::String`].
3847    pub fn into_literal_string(self) -> Option<String> {
3848        self.simplify_to_literal().and_then(|row| {
3849            let datum = row.unpack_first();
3850            if datum.is_null() {
3851                None
3852            } else {
3853                Some(datum.unwrap_str().to_owned())
3854            }
3855        })
3856    }
3857
3858    /// Attempts to simplify this expression to a literal MzTimestamp.
3859    ///
3860    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3861    /// simplified, e.g. because it contains non-literal values or a cast fails.
3862    ///
3863    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3864    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3865    /// See `try_into_literal_int64` as an example.
3866    ///
3867    /// # Panics
3868    ///
3869    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3870    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3871        self.simplify_to_literal().and_then(|row| {
3872            let datum = row.unpack_first();
3873            if datum.is_null() {
3874                None
3875            } else {
3876                Some(datum.unwrap_mz_timestamp())
3877            }
3878        })
3879    }
3880
3881    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3882    /// returns it as an i64.
3883    ///
3884    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3885    /// - it's not a constant expression (as determined by `is_constant`)
3886    /// - evaluates to null
3887    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3888    ///
3889    /// # Panics
3890    ///
3891    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3892    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3893        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3894        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3895        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3896        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3897        // preconditions also of all the other into_literal_... functions.)
3898        if !self.is_constant() {
3899            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3900                "Expected a constant expression, got {}",
3901                self
3902            )));
3903        }
3904        self.clone()
3905            .simplify_to_literal_with_result()
3906            .and_then(|row| {
3907                let datum = row.unpack_first();
3908                if datum.is_null() {
3909                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3910                        "Expected an expression that evaluates to a non-null value, got {}",
3911                        self
3912                    )))
3913                } else {
3914                    Ok(datum.unwrap_int64())
3915                }
3916            })
3917    }
3918
3919    pub fn contains_parameters(&self) -> bool {
3920        let mut contains_parameters = false;
3921        #[allow(deprecated)]
3922        let _ = self.visit_recursively(0, &mut |_depth: usize,
3923                                                expr: &HirScalarExpr|
3924         -> Result<(), ()> {
3925            if let HirScalarExpr::Parameter(..) = expr {
3926                contains_parameters = true;
3927            }
3928            Ok(())
3929        });
3930        contains_parameters
3931    }
3932}
3933
3934/// Yields the direct scalar children of `self`. Stops at `Exists` / `Select`:
3935/// scalars inside subquery bodies are not surfaced. The asymmetry with
3936/// `VisitChildren<Self> for HirRelationExpr` (which does see through scalars
3937/// into subqueries) is what avoids the mutual recursion warned about on the
3938/// [`VisitChildren`] trait.
3939impl VisitChildren<Self> for HirScalarExpr {
3940    fn visit_children<F>(&self, mut f: F)
3941    where
3942        F: FnMut(&Self),
3943    {
3944        use HirScalarExpr::*;
3945        match self {
3946            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3947            CallUnary { expr, .. } => f(expr),
3948            CallBinary { expr1, expr2, .. } => {
3949                f(expr1);
3950                f(expr2);
3951            }
3952            CallVariadic { exprs, .. } => {
3953                for expr in exprs {
3954                    f(expr);
3955                }
3956            }
3957            If {
3958                cond,
3959                then,
3960                els,
3961                name: _,
3962            } => {
3963                f(cond);
3964                f(then);
3965                f(els);
3966            }
3967            Exists(..) | Select(..) => (),
3968            Windowing(expr, _name) => expr.visit_children(f),
3969        }
3970    }
3971
3972    fn visit_mut_children<F>(&mut self, mut f: F)
3973    where
3974        F: FnMut(&mut Self),
3975    {
3976        use HirScalarExpr::*;
3977        match self {
3978            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3979            CallUnary { expr, .. } => f(expr),
3980            CallBinary { expr1, expr2, .. } => {
3981                f(expr1);
3982                f(expr2);
3983            }
3984            CallVariadic { exprs, .. } => {
3985                for expr in exprs {
3986                    f(expr);
3987                }
3988            }
3989            If {
3990                cond,
3991                then,
3992                els,
3993                name: _,
3994            } => {
3995                f(cond);
3996                f(then);
3997                f(els);
3998            }
3999            Exists(..) | Select(..) => (),
4000            Windowing(expr, _name) => expr.visit_mut_children(f),
4001        }
4002    }
4003
4004    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
4005    where
4006        F: FnMut(&Self) -> Result<(), E>,
4007        E: From<RecursionLimitError>,
4008    {
4009        use HirScalarExpr::*;
4010        match self {
4011            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4012            CallUnary { expr, .. } => f(expr)?,
4013            CallBinary { expr1, expr2, .. } => {
4014                f(expr1)?;
4015                f(expr2)?;
4016            }
4017            CallVariadic { exprs, .. } => {
4018                for expr in exprs {
4019                    f(expr)?;
4020                }
4021            }
4022            If {
4023                cond,
4024                then,
4025                els,
4026                name: _,
4027            } => {
4028                f(cond)?;
4029                f(then)?;
4030                f(els)?;
4031            }
4032            Exists(..) | Select(..) => (),
4033            Windowing(expr, _name) => expr.try_visit_children(f)?,
4034        }
4035        Ok(())
4036    }
4037
4038    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
4039    where
4040        F: FnMut(&mut Self) -> Result<(), E>,
4041        E: From<RecursionLimitError>,
4042    {
4043        use HirScalarExpr::*;
4044        match self {
4045            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4046            CallUnary { expr, .. } => f(expr)?,
4047            CallBinary { expr1, expr2, .. } => {
4048                f(expr1)?;
4049                f(expr2)?;
4050            }
4051            CallVariadic { exprs, .. } => {
4052                for expr in exprs {
4053                    f(expr)?;
4054                }
4055            }
4056            If {
4057                cond,
4058                then,
4059                els,
4060                name: _,
4061            } => {
4062                f(cond)?;
4063                f(then)?;
4064                f(els)?;
4065            }
4066            Exists(..) | Select(..) => (),
4067            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
4068        }
4069        Ok(())
4070    }
4071}
4072
4073/// Yields the immediate `HirRelationExpr` children of `self` (the bodies of
4074/// `Exists` / `Select`); does not descend into them.
4075impl VisitChildren<HirRelationExpr> for HirScalarExpr {
4076    fn visit_children<F>(&self, mut f: F)
4077    where
4078        F: FnMut(&HirRelationExpr),
4079    {
4080        use HirScalarExpr::*;
4081        match self {
4082            Column(..)
4083            | Parameter(..)
4084            | Literal(..)
4085            | CallUnmaterializable(..)
4086            | CallUnary { .. }
4087            | CallBinary { .. }
4088            | CallVariadic { .. }
4089            | If { .. }
4090            | Windowing(..) => (),
4091            Exists(expr, _name) | Select(expr, _name) => f(expr),
4092        }
4093    }
4094
4095    fn visit_mut_children<F>(&mut self, mut f: F)
4096    where
4097        F: FnMut(&mut HirRelationExpr),
4098    {
4099        use HirScalarExpr::*;
4100        match self {
4101            Column(..)
4102            | Parameter(..)
4103            | Literal(..)
4104            | CallUnmaterializable(..)
4105            | CallUnary { .. }
4106            | CallBinary { .. }
4107            | CallVariadic { .. }
4108            | If { .. }
4109            | Windowing(..) => (),
4110            Exists(expr, _name) | Select(expr, _name) => f(expr),
4111        }
4112    }
4113
4114    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
4115    where
4116        F: FnMut(&HirRelationExpr) -> Result<(), E>,
4117        E: From<RecursionLimitError>,
4118    {
4119        use HirScalarExpr::*;
4120        match self {
4121            Column(..)
4122            | Parameter(..)
4123            | Literal(..)
4124            | CallUnmaterializable(..)
4125            | CallUnary { .. }
4126            | CallBinary { .. }
4127            | CallVariadic { .. }
4128            | If { .. }
4129            | Windowing(..) => (),
4130            Exists(expr, _name) | Select(expr, _name) => f(expr)?,
4131        }
4132        Ok(())
4133    }
4134
4135    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
4136    where
4137        F: FnMut(&mut HirRelationExpr) -> Result<(), E>,
4138        E: From<RecursionLimitError>,
4139    {
4140        use HirScalarExpr::*;
4141        match self {
4142            Column(..)
4143            | Parameter(..)
4144            | Literal(..)
4145            | CallUnmaterializable(..)
4146            | CallUnary { .. }
4147            | CallBinary { .. }
4148            | CallVariadic { .. }
4149            | If { .. }
4150            | Windowing(..) => (),
4151            Exists(expr, _name) | Select(expr, _name) => f(expr)?,
4152        }
4153        Ok(())
4154    }
4155}
4156
4157impl AbstractExpr for HirScalarExpr {
4158    type Type = SqlColumnType;
4159
4160    fn typ(
4161        &self,
4162        outers: &[SqlRelationType],
4163        inner: &SqlRelationType,
4164        params: &BTreeMap<usize, SqlScalarType>,
4165    ) -> Self::Type {
4166        stack::maybe_grow(|| match self {
4167            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
4168                if *level == 0 {
4169                    inner.column_types[*column].clone()
4170                } else {
4171                    outers[*level - 1].column_types[*column].clone()
4172                }
4173            }
4174            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
4175            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
4176            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(),
4177            HirScalarExpr::CallUnary {
4178                expr,
4179                func,
4180                name: _,
4181            } => func.output_sql_type(expr.typ(outers, inner, params)),
4182            HirScalarExpr::CallBinary {
4183                expr1,
4184                expr2,
4185                func,
4186                name: _,
4187            } => func.output_sql_type(&[
4188                expr1.typ(outers, inner, params),
4189                expr2.typ(outers, inner, params),
4190            ]),
4191            HirScalarExpr::CallVariadic {
4192                exprs,
4193                func,
4194                name: _,
4195            } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
4196            HirScalarExpr::If {
4197                cond: _,
4198                then,
4199                els,
4200                name: _,
4201            } => {
4202                let then_type = then.typ(outers, inner, params);
4203                let else_type = els.typ(outers, inner, params);
4204                then_type.sql_union(&else_type).unwrap() // HIR deliberately not using `union`
4205            }
4206            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
4207            HirScalarExpr::Select(expr, _name) => {
4208                let mut outers = outers.to_vec();
4209                outers.insert(0, inner.clone());
4210                expr.typ(&outers, params)
4211                    .column_types
4212                    .into_element()
4213                    .nullable(true)
4214            }
4215            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
4216        })
4217    }
4218}
4219
4220impl AggregateExpr {
4221    pub fn typ(
4222        &self,
4223        outers: &[SqlRelationType],
4224        inner: &SqlRelationType,
4225        params: &BTreeMap<usize, SqlScalarType>,
4226    ) -> SqlColumnType {
4227        self.func
4228            .output_sql_type(self.expr.typ(outers, inner, params))
4229    }
4230
4231    /// Returns whether the expression is COUNT(*) or not.  Note that
4232    /// when we define the count builtin in sql::func, we convert
4233    /// COUNT(*) to COUNT(true), making it indistinguishable from
4234    /// literal COUNT(true), but we prefer to consider this as the
4235    /// former.
4236    ///
4237    /// (MIR has the same `is_count_asterisk`.)
4238    pub fn is_count_asterisk(&self) -> bool {
4239        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
4240    }
4241}