Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Marker type that names the HIR representation for the [`IR`] and
/// [`AlgExcept`] traits implemented in this file.
// This unit struct has no `Debug` impl; the lint is silenced instead of
// deriving one.
#[allow(missing_debug_implementations)]
pub struct Hir;
48
/// Binds the generic [`IR`] abstraction to the concrete HIR expression types.
impl IR for Hir {
    /// Relational expressions of HIR are [`HirRelationExpr`]s.
    type Relation = HirRelationExpr;
    /// Scalar expressions of HIR are [`HirScalarExpr`]s.
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56        if *all {
57            let rhs = rhs.negate();
58            HirRelationExpr::union(lhs, rhs).threshold()
59        } else {
60            let lhs = lhs.distinct();
61            let rhs = rhs.distinct().negate();
62            HirRelationExpr::union(lhs, rhs).threshold()
63        }
64    }
65
66    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67        let mut result = None;
68
69        use HirRelationExpr::*;
70        if let Threshold { input } = expr {
71            if let Union { base: lhs, inputs } = input.as_ref() {
72                if let [rhs] = &inputs[..] {
73                    if let Negate { input: rhs } = rhs {
74                        match (lhs.as_ref(), rhs.as_ref()) {
75                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
76                                let all = false;
77                                let lhs = lhs.as_ref();
78                                let rhs = rhs.as_ref();
79                                result = Some(Except { all, lhs, rhs })
80                            }
81                            (lhs, rhs) => {
82                                let all = true;
83                                result = Some(Except { all, lhs, rhs })
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        result
92    }
93}
94
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
///
/// Note that `PartialOrd`/`Ord` and the serde impls are derived, so the order
/// of variants and fields is part of this type's observable behavior.
pub enum HirRelationExpr {
    /// A collection given directly as `rows`, described by the relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, described by the
    /// relation type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        /// The name that `value` is bound to.
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// See the variant of the same name in [`mz_expr::MirRelationExpr`].
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// See the variant of the same name in [`mz_expr::MirRelationExpr`].
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call of the table function `func` with the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// See the variant of the same name in [`mz_expr::MirRelationExpr`].
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    /// See the variant of the same name in [`mz_expr::MirRelationExpr`], if any.
    /// (`AlgExcept::except` in this file wraps both inputs in it for the
    /// non-`ALL` form of `EXCEPT`.)
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// See the variant of the same name in [`mz_expr::MirRelationExpr`].
    /// (`AlgExcept::except` in this file applies it to the right-hand side
    /// before unioning.)
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every collection in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
208
/// Stored column metadata.
///
/// An optional, shared name string that is attached to every
/// [`HirScalarExpr`] node. It is wrapped in [`TreatAsEqual`] — presumably so
/// that the name does not take part in comparisons of expressions; confirm
/// against `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
211
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant additionally carries a [`NameMetadata`], either as its last
/// tuple element or as its `name` field.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A query parameter (such as `$1`), identified by a number.
    Parameter(usize, NameMetadata),
    /// A literal value, stored as a [`Row`] together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call of an [`UnmaterializableFunc`]; it carries no argument expressions.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call of the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call of the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call of the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional with condition `cond` and branches `then` and `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// The invocation of a window function; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
263
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, together with its own parameters.
    pub func: WindowExprType,
    /// The expressions that define the partitioning.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above (the "inner" `order_by`),
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field (the "outer"
    ///    `order_by`).
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
292
293impl WindowExpr {
294    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
295    where
296        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
297    {
298        #[allow(deprecated)]
299        self.func.visit_expressions(f)?;
300        for expr in self.partition_by.iter() {
301            f(expr)?;
302        }
303        for expr in self.order_by.iter() {
304            f(expr)?;
305        }
306        Ok(())
307    }
308
309    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
310    where
311        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
312    {
313        #[allow(deprecated)]
314        self.func.visit_expressions_mut(f)?;
315        for expr in self.partition_by.iter_mut() {
316            f(expr)?;
317        }
318        for expr in self.order_by.iter_mut() {
319            f(expr)?;
320        }
321        Ok(())
322    }
323}
324
325impl VisitChildren<HirScalarExpr> for WindowExpr {
326    fn visit_children<F>(&self, mut f: F)
327    where
328        F: FnMut(&HirScalarExpr),
329    {
330        self.func.visit_children(&mut f);
331        for expr in self.partition_by.iter() {
332            f(expr);
333        }
334        for expr in self.order_by.iter() {
335            f(expr);
336        }
337    }
338
339    fn visit_mut_children<F>(&mut self, mut f: F)
340    where
341        F: FnMut(&mut HirScalarExpr),
342    {
343        self.func.visit_mut_children(&mut f);
344        for expr in self.partition_by.iter_mut() {
345            f(expr);
346        }
347        for expr in self.order_by.iter_mut() {
348            f(expr);
349        }
350    }
351
352    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
353    where
354        F: FnMut(&HirScalarExpr) -> Result<(), E>,
355        E: From<RecursionLimitError>,
356    {
357        self.func.try_visit_children(&mut f)?;
358        for expr in self.partition_by.iter() {
359            f(expr)?;
360        }
361        for expr in self.order_by.iter() {
362            f(expr)?;
363        }
364        Ok(())
365    }
366
367    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
368    where
369        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
370        E: From<RecursionLimitError>,
371    {
372        self.func.try_visit_mut_children(&mut f)?;
373        for expr in self.partition_by.iter_mut() {
374            f(expr)?;
375        }
376        for expr in self.order_by.iter_mut() {
377            f(expr)?;
378        }
379        Ok(())
380    }
381}
382
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function; see [`ScalarWindowExpr`].
    Scalar(ScalarWindowExpr),
    /// A value window function; see [`ValueWindowExpr`].
    Value(ValueWindowExpr),
    /// An aggregate window function; see [`AggregateWindowExpr`].
    Aggregate(AggregateWindowExpr),
}
415
416impl WindowExprType {
417    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
418    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
419    where
420        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
421    {
422        #[allow(deprecated)]
423        match self {
424            Self::Scalar(expr) => expr.visit_expressions(f),
425            Self::Value(expr) => expr.visit_expressions(f),
426            Self::Aggregate(expr) => expr.visit_expressions(f),
427        }
428    }
429
430    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
431    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
432    where
433        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
434    {
435        #[allow(deprecated)]
436        match self {
437            Self::Scalar(expr) => expr.visit_expressions_mut(f),
438            Self::Value(expr) => expr.visit_expressions_mut(f),
439            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
440        }
441    }
442
443    fn typ(
444        &self,
445        outers: &[SqlRelationType],
446        inner: &SqlRelationType,
447        params: &BTreeMap<usize, SqlScalarType>,
448    ) -> SqlColumnType {
449        match self {
450            Self::Scalar(expr) => expr.typ(outers, inner, params),
451            Self::Value(expr) => expr.typ(outers, inner, params),
452            Self::Aggregate(expr) => expr.typ(outers, inner, params),
453        }
454    }
455}
456
457impl VisitChildren<HirScalarExpr> for WindowExprType {
458    fn visit_children<F>(&self, f: F)
459    where
460        F: FnMut(&HirScalarExpr),
461    {
462        match self {
463            Self::Scalar(_) => (),
464            Self::Value(expr) => expr.visit_children(f),
465            Self::Aggregate(expr) => expr.visit_children(f),
466        }
467    }
468
469    fn visit_mut_children<F>(&mut self, f: F)
470    where
471        F: FnMut(&mut HirScalarExpr),
472    {
473        match self {
474            Self::Scalar(_) => (),
475            Self::Value(expr) => expr.visit_mut_children(f),
476            Self::Aggregate(expr) => expr.visit_mut_children(f),
477        }
478    }
479
480    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
481    where
482        F: FnMut(&HirScalarExpr) -> Result<(), E>,
483        E: From<RecursionLimitError>,
484    {
485        match self {
486            Self::Scalar(_) => Ok(()),
487            Self::Value(expr) => expr.try_visit_children(f),
488            Self::Aggregate(expr) => expr.try_visit_children(f),
489        }
490    }
491
492    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
493    where
494        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
495        E: From<RecursionLimitError>,
496    {
497        match self {
498            Self::Scalar(_) => Ok(()),
499            Self::Value(expr) => expr.try_visit_mut_children(f),
500            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
501        }
502    }
503}
504
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A scalar window function call. It has no argument expressions, only an
/// ordering.
pub struct ScalarWindowExpr {
    /// Which scalar window function is called.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
520
521impl ScalarWindowExpr {
522    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
523    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
524    where
525        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
526    {
527        match self.func {
528            ScalarWindowFunc::RowNumber => {}
529            ScalarWindowFunc::Rank => {}
530            ScalarWindowFunc::DenseRank => {}
531        }
532        Ok(())
533    }
534
535    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
536    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
537    where
538        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
539    {
540        match self.func {
541            ScalarWindowFunc::RowNumber => {}
542            ScalarWindowFunc::Rank => {}
543            ScalarWindowFunc::DenseRank => {}
544        }
545        Ok(())
546    }
547
548    fn typ(
549        &self,
550        _outers: &[SqlRelationType],
551        _inner: &SqlRelationType,
552        _params: &BTreeMap<usize, SqlScalarType>,
553    ) -> SqlColumnType {
554        self.func.output_type()
555    }
556
557    pub fn into_expr(self) -> mz_expr::AggregateFunc {
558        match self.func {
559            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
560                order_by: self.order_by,
561            },
562            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
563                order_by: self.order_by,
564            },
565            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
566                order_by: self.order_by,
567            },
568        }
569    }
570}
571
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
589
590impl Display for ScalarWindowFunc {
591    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
592        match self {
593            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
594            ScalarWindowFunc::Rank => write!(f, "rank"),
595            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
596        }
597    }
598}
599
600impl ScalarWindowFunc {
601    pub fn output_type(&self) -> SqlColumnType {
602        match self {
603            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
604            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
605            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
606        }
607    }
608}
609
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A value window function call together with its arguments, ordering, and
/// frame.
pub struct ValueWindowExpr {
    /// Which value window function is called.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. `ValueWindowFunc::into_expr` forwards it to
    /// `FirstValue`/`LastValue` only.
    pub window_frame: WindowFrame,
    /// The ignore-nulls flag. `ValueWindowFunc::into_expr` forwards it to
    /// `Lag`/`Lead` only.
    pub ignore_nulls: bool,
}
634
635impl Display for ValueWindowFunc {
636    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
637        match self {
638            ValueWindowFunc::Lag => write!(f, "lag"),
639            ValueWindowFunc::Lead => write!(f, "lead"),
640            ValueWindowFunc::FirstValue => write!(f, "first_value"),
641            ValueWindowFunc::LastValue => write!(f, "last_value"),
642            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
643        }
644    }
645}
646
647impl ValueWindowExpr {
648    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
649    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
650    where
651        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
652    {
653        f(&self.args)
654    }
655
656    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
657    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
658    where
659        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
660    {
661        f(&mut self.args)
662    }
663
664    fn typ(
665        &self,
666        outers: &[SqlRelationType],
667        inner: &SqlRelationType,
668        params: &BTreeMap<usize, SqlScalarType>,
669    ) -> SqlColumnType {
670        self.func.output_type(self.args.typ(outers, inner, params))
671    }
672
673    /// Converts into `mz_expr::AggregateFunc`.
674    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
675        (
676            self.args,
677            self.func
678                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
679        )
680    }
681}
682
683impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
684    fn visit_children<F>(&self, mut f: F)
685    where
686        F: FnMut(&HirScalarExpr),
687    {
688        f(&self.args)
689    }
690
691    fn visit_mut_children<F>(&mut self, mut f: F)
692    where
693        F: FnMut(&mut HirScalarExpr),
694    {
695        f(&mut self.args)
696    }
697
698    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
699    where
700        F: FnMut(&HirScalarExpr) -> Result<(), E>,
701        E: From<RecursionLimitError>,
702    {
703        f(&self.args)
704    }
705
706    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
707    where
708        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
709        E: From<RecursionLimitError>,
710    {
711        f(&mut self.args)
712    }
713}
714
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// Displayed as `lag`.
    Lag,
    /// Displayed as `lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions fused into one; displayed as
    /// `fused[...]` around the names of the constituent functions.
    Fused(Vec<ValueWindowFunc>),
}
734
735impl ValueWindowFunc {
736    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
737        match self {
738            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
739                // The input is a (value, offset, default) record, so extract the type of the first arg
740                input_type.scalar_type.unwrap_record_element_type()[0]
741                    .clone()
742                    .nullable(true)
743            }
744            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
745                input_type.scalar_type.nullable(true)
746            }
747            ValueWindowFunc::Fused(funcs) => {
748                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
749                SqlScalarType::Record {
750                    fields: funcs
751                        .iter()
752                        .zip_eq(input_types)
753                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
754                        .collect(),
755                    custom_id: None,
756                }
757                .nullable(false)
758            }
759        }
760    }
761
762    pub fn into_expr(
763        self,
764        order_by: Vec<ColumnOrder>,
765        window_frame: WindowFrame,
766        ignore_nulls: bool,
767    ) -> mz_expr::AggregateFunc {
768        match self {
769            // Lag and Lead are fundamentally the same function, just with opposite directions
770            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
771                order_by,
772                lag_lead: mz_expr::LagLeadType::Lag,
773                ignore_nulls,
774            },
775            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
776                order_by,
777                lag_lead: mz_expr::LagLeadType::Lead,
778                ignore_nulls,
779            },
780            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
781                order_by,
782                window_frame,
783            },
784            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
785                order_by,
786                window_frame,
787            },
788            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
789                funcs: funcs
790                    .into_iter()
791                    .map(|func| {
792                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
793                    })
794                    .collect(),
795                order_by,
796            },
797        }
798    }
799}
800
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// An aggregation computed as a window function (e.g. `sum(x) OVER (...)`).
pub struct AggregateWindowExpr {
    /// The aggregate function together with its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame, passed on to the lowered aggregate by `into_expr`.
    pub window_frame: WindowFrame,
}
817
818impl AggregateWindowExpr {
819    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
820    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
821    where
822        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
823    {
824        f(&self.aggregate_expr.expr)
825    }
826
827    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
828    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
829    where
830        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
831    {
832        f(&mut self.aggregate_expr.expr)
833    }
834
835    fn typ(
836        &self,
837        outers: &[SqlRelationType],
838        inner: &SqlRelationType,
839        params: &BTreeMap<usize, SqlScalarType>,
840    ) -> SqlColumnType {
841        self.aggregate_expr
842            .func
843            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
844    }
845
846    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
847        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
848            (
849                self.aggregate_expr.expr,
850                FusedWindowAggregate {
851                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
852                    order_by: self.order_by,
853                    window_frame: self.window_frame,
854                },
855            )
856        } else {
857            (
858                self.aggregate_expr.expr,
859                WindowAggregate {
860                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
861                    order_by: self.order_by,
862                    window_frame: self.window_frame,
863                },
864            )
865        }
866    }
867}
868
869impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
870    fn visit_children<F>(&self, mut f: F)
871    where
872        F: FnMut(&HirScalarExpr),
873    {
874        f(&self.aggregate_expr.expr)
875    }
876
877    fn visit_mut_children<F>(&mut self, mut f: F)
878    where
879        F: FnMut(&mut HirScalarExpr),
880    {
881        f(&mut self.aggregate_expr.expr)
882    }
883
884    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
885    where
886        F: FnMut(&HirScalarExpr) -> Result<(), E>,
887        E: From<RecursionLimitError>,
888    {
889        f(&self.aggregate_expr.expr)
890    }
891
892    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
893    where
894        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
895        E: From<RecursionLimitError>,
896    {
897        f(&mut self.aggregate_expr.expr)
898    }
899}
900
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that is already a [`HirScalarExpr`].
    Coerced(HirScalarExpr),
    /// A parameter (such as `$1`), identified by a number.
    Parameter(usize),
    /// The literal `NULL`.
    LiteralNull,
    /// A string literal (such as `'42'`).
    LiteralString(String),
    /// A literal record whose elements are themselves coercible expressions.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
933
934impl CoercibleScalarExpr {
935    pub fn type_as(
936        self,
937        ecx: &ExprContext,
938        ty: &SqlScalarType,
939    ) -> Result<HirScalarExpr, PlanError> {
940        let expr = typeconv::plan_coerce(ecx, self, ty)?;
941        let expr_ty = ecx.scalar_type(&expr);
942        if ty != &expr_ty {
943            sql_bail!(
944                "{} must have type {}, not type {}",
945                ecx.name,
946                ecx.humanize_scalar_type(ty, false),
947                ecx.humanize_scalar_type(&expr_ty, false),
948            );
949        }
950        Ok(expr)
951    }
952
953    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
954        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
955    }
956
957    pub fn cast_to(
958        self,
959        ecx: &ExprContext,
960        ccx: CastContext,
961        ty: &SqlScalarType,
962    ) -> Result<HirScalarExpr, PlanError> {
963        let expr = typeconv::plan_coerce(ecx, self, ty)?;
964        typeconv::plan_cast(ecx, ccx, expr, ty)
965    }
966}
967
968/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The type is fully determined.
    Coerced(SqlColumnType),
    /// A record literal; each field has its own, possibly still uncoerced,
    /// column type.
    Record(Vec<CoercibleColumnType>),
    /// The type has not been determined yet.
    Uncoerced,
}
975
976impl CoercibleColumnType {
977    /// Reports the nullability of the type.
978    pub fn nullable(&self) -> bool {
979        match self {
980            // A coerced value's nullability is known.
981            CoercibleColumnType::Coerced(ct) => ct.nullable,
982
983            // A literal record can never be null.
984            CoercibleColumnType::Record(_) => false,
985
986            // An uncoerced literal may be the literal `NULL`, so we have
987            // to conservatively assume it is nullable.
988            CoercibleColumnType::Uncoerced => true,
989        }
990    }
991}
992
993/// The scalar type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// The scalar type is fully determined.
    Coerced(SqlScalarType),
    /// A record literal; each field has its own, possibly still uncoerced,
    /// column type.
    Record(Vec<CoercibleColumnType>),
    /// The scalar type has not been determined yet.
    Uncoerced,
}
1000
1001impl CoercibleScalarType {
1002    /// Reports whether the scalar type has been coerced.
1003    pub fn is_coerced(&self) -> bool {
1004        matches!(self, CoercibleScalarType::Coerced(_))
1005    }
1006
1007    /// Returns the coerced scalar type, if the type is coerced.
1008    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1009        match self {
1010            CoercibleScalarType::Coerced(t) => Some(t),
1011            _ => None,
1012        }
1013    }
1014
1015    /// If the type is coerced, apply the mapping function to the contained
1016    /// scalar type.
1017    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1018    where
1019        F: FnOnce(SqlScalarType) -> SqlScalarType,
1020    {
1021        match self {
1022            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1023            _ => self,
1024        }
1025    }
1026
1027    /// If the type is an coercible record, forcibly converts to a coerced
1028    /// record type. Any uncoerced field types are assumed to be of type text.
1029    ///
1030    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
1031    /// accepts a type hint that can indicate the types of uncoerced field
1032    /// types.
1033    pub fn force_coerced_if_record(&mut self) {
1034        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1035            let mut fields = vec![];
1036            for (i, uf) in uncoerced_fields.enumerate() {
1037                let name = ColumnName::from(format!("f{}", i + 1));
1038                let ty = match uf {
1039                    CoercibleColumnType::Coerced(ty) => ty,
1040                    CoercibleColumnType::Record(mut fields) => {
1041                        convert(fields.drain(..)).nullable(false)
1042                    }
1043                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1044                };
1045                fields.push((name, ty))
1046            }
1047            SqlScalarType::Record {
1048                fields: fields.into(),
1049                custom_id: None,
1050            }
1051        }
1052
1053        if let CoercibleScalarType::Record(fields) = self {
1054            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1055        }
1056    }
1057}
1058
1059/// An expression whose type can be ascertained.
1060///
1061/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object that describes this expression's type.
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// NOTE(review): by analogy with `HirRelationExpr::typ`, `outers`
    /// presumably gives the types of the outer relations and `params` the
    /// types of query parameters; `inner` is presumably the type of the
    /// relation the expression is evaluated against — confirm.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
1073
1074impl AbstractExpr for CoercibleScalarExpr {
1075    type Type = CoercibleColumnType;
1076
1077    fn typ(
1078        &self,
1079        outers: &[SqlRelationType],
1080        inner: &SqlRelationType,
1081        params: &BTreeMap<usize, SqlScalarType>,
1082    ) -> Self::Type {
1083        match self {
1084            CoercibleScalarExpr::Coerced(expr) => {
1085                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1086            }
1087            CoercibleScalarExpr::LiteralRecord(scalars) => {
1088                let fields = scalars
1089                    .iter()
1090                    .map(|s| s.typ(outers, inner, params))
1091                    .collect();
1092                CoercibleColumnType::Record(fields)
1093            }
1094            _ => CoercibleColumnType::Uncoerced,
1095        }
1096    }
1097}
1098
1099/// A column type-like object whose underlying scalar type-like object can be
1100/// ascertained.
1101///
1102/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object returned by
    /// [`scalar_type`](AbstractColumnType::scalar_type).
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1110
1111impl AbstractColumnType for SqlColumnType {
1112    type AbstractScalarType = SqlScalarType;
1113
1114    fn scalar_type(self) -> Self::AbstractScalarType {
1115        self.scalar_type
1116    }
1117}
1118
1119impl AbstractColumnType for CoercibleColumnType {
1120    type AbstractScalarType = CoercibleScalarType;
1121
1122    fn scalar_type(self) -> Self::AbstractScalarType {
1123        match self {
1124            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1125            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1126            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1127        }
1128    }
1129}
1130
1131impl From<HirScalarExpr> for CoercibleScalarExpr {
1132    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1133        CoercibleScalarExpr::Coerced(expr)
1134    }
1135}
1136
1137/// A leveled column reference.
1138///
1139/// In the course of decorrelation, multiple levels of nested subqueries are
1140/// traversed, and references to columns may correspond to different levels
1141/// of containing outer subqueries.
1142///
1143/// A `ColumnRef` allows expressions to refer to columns while being clear
1144/// about which level the column references without manually performing the
1145/// bookkeeping tracking their actual column locations.
1146///
1147/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
1148/// from the reference, using `column` as a unique identifier in that subquery level.
1149/// A `level` of zero corresponds to the current scope, and levels increase to
1150/// indicate subqueries further "outwards".
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize
)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier used.
    pub column: usize,
}
1169
/// The kind of a `HirRelationExpr::Join`.
///
/// When a join is typed, the kind decides which side's columns are forced to
/// be nullable in the output.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum JoinKind {
    /// Neither side's columns are forced to be nullable.
    Inner,
    /// The right side's columns are forced to be nullable.
    LeftOuter,
    /// The left side's columns are forced to be nullable.
    RightOuter,
    /// Both sides' columns are forced to be nullable.
    FullOuter,
}
1187
1188impl fmt::Display for JoinKind {
1189    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1190        write!(
1191            f,
1192            "{}",
1193            match self {
1194                JoinKind::Inner => "Inner",
1195                JoinKind::LeftOuter => "LeftOuter",
1196                JoinKind::RightOuter => "RightOuter",
1197                JoinKind::FullOuter => "FullOuter",
1198            }
1199        )
1200    }
1201}
1202
1203impl JoinKind {
1204    pub fn can_be_correlated(&self) -> bool {
1205        match self {
1206            JoinKind::Inner | JoinKind::LeftOuter => true,
1207            JoinKind::RightOuter | JoinKind::FullOuter => false,
1208        }
1209    }
1210}
1211
/// An aggregate function together with the scalar expression it is applied
/// to.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The expression the aggregate function is applied to.
    pub expr: Box<HirScalarExpr>,
    /// NOTE(review): presumably whether the aggregate only considers distinct
    /// input values (SQL `DISTINCT`) — confirm against the lowering code.
    pub distinct: bool,
}
1228
1229/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
1230/// types may be different.
1231///
1232/// Specifically, the nullability of the aggregate columns is more common
1233/// here than in `expr`, as these aggregates may be applied over empty
1234/// result sets and should be null in those cases, whereas `expr` variants
1235/// only return null values when supplied nulls as input.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    Count,
    Any,
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`s
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum`s into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// An order-sensitive aggregate whose output type is a string.
    ///
    /// NOTE(review): presumably, as with the other ordered aggregates, the
    /// input is a `Datum::List` whose first element carries the value to
    /// aggregate and whose other elements are columns used by `order_by` —
    /// confirm against `mz_expr::AggregateFunc::StringAgg`.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, whose each
    /// component will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1342
1343impl AggregateFunc {
1344    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1345    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1346        match self {
1347            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1348            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1349            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1350            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1351            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1352            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1353            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1354            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1355            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1356            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1357            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1358            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1359            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1360            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1361            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1362            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1363            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1364            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1365            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1366            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1367            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1368            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1369            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1370            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1371            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1372            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1373            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1374            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1375            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1376            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1377            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1378            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1379            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1380            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1381            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1382            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1383            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1384            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1385            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1386            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1387            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1388            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1389            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1390            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1391            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1392            AggregateFunc::All => mz_expr::AggregateFunc::All,
1393            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1394            AggregateFunc::JsonbObjectAgg { order_by } => {
1395                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1396            }
1397            AggregateFunc::MapAgg {
1398                order_by,
1399                value_type,
1400            } => mz_expr::AggregateFunc::MapAgg {
1401                order_by,
1402                value_type,
1403            },
1404            AggregateFunc::ArrayConcat { order_by } => {
1405                mz_expr::AggregateFunc::ArrayConcat { order_by }
1406            }
1407            AggregateFunc::ListConcat { order_by } => {
1408                mz_expr::AggregateFunc::ListConcat { order_by }
1409            }
1410            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1411            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1412            // `AggregateWindowExpr::into_expr`.
1413            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1414                panic!("into_expr called on FusedWindowAgg")
1415            }
1416            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1417        }
1418    }
1419
1420    /// Returns a datum whose inclusion in the aggregation will not change its
1421    /// result.
1422    ///
1423    /// # Panics
1424    ///
1425    /// Panics if called on a `FusedWindowAgg`.
1426    pub fn identity_datum(&self) -> Datum<'static> {
1427        match self {
1428            AggregateFunc::Any => Datum::False,
1429            AggregateFunc::All => Datum::True,
1430            AggregateFunc::Dummy => Datum::Dummy,
1431            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1432            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1433            AggregateFunc::MaxNumeric
1434            | AggregateFunc::MaxInt16
1435            | AggregateFunc::MaxInt32
1436            | AggregateFunc::MaxInt64
1437            | AggregateFunc::MaxUInt16
1438            | AggregateFunc::MaxUInt32
1439            | AggregateFunc::MaxUInt64
1440            | AggregateFunc::MaxMzTimestamp
1441            | AggregateFunc::MaxFloat32
1442            | AggregateFunc::MaxFloat64
1443            | AggregateFunc::MaxBool
1444            | AggregateFunc::MaxString
1445            | AggregateFunc::MaxDate
1446            | AggregateFunc::MaxTimestamp
1447            | AggregateFunc::MaxTimestampTz
1448            | AggregateFunc::MaxInterval
1449            | AggregateFunc::MaxTime
1450            | AggregateFunc::MinNumeric
1451            | AggregateFunc::MinInt16
1452            | AggregateFunc::MinInt32
1453            | AggregateFunc::MinInt64
1454            | AggregateFunc::MinUInt16
1455            | AggregateFunc::MinUInt32
1456            | AggregateFunc::MinUInt64
1457            | AggregateFunc::MinMzTimestamp
1458            | AggregateFunc::MinFloat32
1459            | AggregateFunc::MinFloat64
1460            | AggregateFunc::MinBool
1461            | AggregateFunc::MinString
1462            | AggregateFunc::MinDate
1463            | AggregateFunc::MinTimestamp
1464            | AggregateFunc::MinTimestampTz
1465            | AggregateFunc::MinInterval
1466            | AggregateFunc::MinTime
1467            | AggregateFunc::SumInt16
1468            | AggregateFunc::SumInt32
1469            | AggregateFunc::SumInt64
1470            | AggregateFunc::SumUInt16
1471            | AggregateFunc::SumUInt32
1472            | AggregateFunc::SumUInt64
1473            | AggregateFunc::SumFloat32
1474            | AggregateFunc::SumFloat64
1475            | AggregateFunc::SumNumeric
1476            | AggregateFunc::Count
1477            | AggregateFunc::JsonbAgg { .. }
1478            | AggregateFunc::JsonbObjectAgg { .. }
1479            | AggregateFunc::MapAgg { .. }
1480            | AggregateFunc::StringAgg { .. } => Datum::Null,
1481            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1482                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1483                // in HIR planning, because it is introduced only during HIR transformation.
1484                //
1485                // The implementation could be something like the following, except that we need to
1486                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1487                // ```
1488                // let temp_storage = RowArena::new();
1489                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1490                // ```
1491                panic!("FusedWindowAgg doesn't have an identity_datum")
1492            }
1493        }
1494    }
1495
1496    /// The output column type for the result of an aggregation.
1497    ///
1498    /// The output column type also contains nullability information, which
1499    /// is (without further information) true for aggregations that are not
1500    /// counts.
1501    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1502        let scalar_type = match self {
1503            AggregateFunc::Count => SqlScalarType::Int64,
1504            AggregateFunc::Any => SqlScalarType::Bool,
1505            AggregateFunc::All => SqlScalarType::Bool,
1506            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1507            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1508            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1509            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1510            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1511                max_scale: Some(NumericMaxScale::ZERO),
1512            },
1513            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1514            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1515                max_scale: Some(NumericMaxScale::ZERO),
1516            },
1517            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1518                value_type: Box::new(value_type.clone()),
1519                custom_id: None,
1520            },
1521            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1522                match input_type.scalar_type {
1523                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1524                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1525                    _ => unreachable!(),
1526                }
1527            }
1528            AggregateFunc::MaxNumeric
1529            | AggregateFunc::MaxInt16
1530            | AggregateFunc::MaxInt32
1531            | AggregateFunc::MaxInt64
1532            | AggregateFunc::MaxUInt16
1533            | AggregateFunc::MaxUInt32
1534            | AggregateFunc::MaxUInt64
1535            | AggregateFunc::MaxMzTimestamp
1536            | AggregateFunc::MaxFloat32
1537            | AggregateFunc::MaxFloat64
1538            | AggregateFunc::MaxBool
1539            | AggregateFunc::MaxString
1540            | AggregateFunc::MaxDate
1541            | AggregateFunc::MaxTimestamp
1542            | AggregateFunc::MaxTimestampTz
1543            | AggregateFunc::MaxInterval
1544            | AggregateFunc::MaxTime
1545            | AggregateFunc::MinNumeric
1546            | AggregateFunc::MinInt16
1547            | AggregateFunc::MinInt32
1548            | AggregateFunc::MinInt64
1549            | AggregateFunc::MinUInt16
1550            | AggregateFunc::MinUInt32
1551            | AggregateFunc::MinUInt64
1552            | AggregateFunc::MinMzTimestamp
1553            | AggregateFunc::MinFloat32
1554            | AggregateFunc::MinFloat64
1555            | AggregateFunc::MinBool
1556            | AggregateFunc::MinString
1557            | AggregateFunc::MinDate
1558            | AggregateFunc::MinTimestamp
1559            | AggregateFunc::MinTimestampTz
1560            | AggregateFunc::MinInterval
1561            | AggregateFunc::MinTime
1562            | AggregateFunc::SumFloat32
1563            | AggregateFunc::SumFloat64
1564            | AggregateFunc::SumNumeric
1565            | AggregateFunc::Dummy => input_type.scalar_type,
1566            AggregateFunc::FusedWindowAgg { funcs } => {
1567                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1568                SqlScalarType::Record {
1569                    fields: funcs
1570                        .iter()
1571                        .zip_eq(input_types)
1572                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1573                        .collect(),
1574                    custom_id: None,
1575                }
1576            }
1577        };
1578        // max/min/sum return null on empty sets
1579        let nullable = !matches!(self, AggregateFunc::Count);
1580        scalar_type.nullable(nullable)
1581    }
1582
1583    pub fn is_order_sensitive(&self) -> bool {
1584        use AggregateFunc::*;
1585        matches!(
1586            self,
1587            JsonbAgg { .. }
1588                | JsonbObjectAgg { .. }
1589                | MapAgg { .. }
1590                | ArrayConcat { .. }
1591                | ListConcat { .. }
1592                | StringAgg { .. }
1593        )
1594    }
1595}
1596
1597impl HirRelationExpr {
1598    /// Gets the SQL type of a self-contained, top-level expression.
1599    pub fn top_level_typ(&self) -> SqlRelationType {
1600        self.typ(&[], &BTreeMap::new())
1601    }
1602
1603    /// Gets the SQL type of the expression.
1604    ///
1605    /// `outers` gives types for outer relations.
1606    /// `params` gives types for parameters.
1607    pub fn typ(
1608        &self,
1609        outers: &[SqlRelationType],
1610        params: &BTreeMap<usize, SqlScalarType>,
1611    ) -> SqlRelationType {
1612        stack::maybe_grow(|| match self {
1613            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1614            HirRelationExpr::Get { typ, .. } => typ.clone(),
1615            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1616            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1617            HirRelationExpr::Project { input, outputs } => {
1618                let input_typ = input.typ(outers, params);
1619                SqlRelationType::new(
1620                    outputs
1621                        .iter()
1622                        .map(|&i| input_typ.column_types[i].clone())
1623                        .collect(),
1624                )
1625            }
1626            HirRelationExpr::Map { input, scalars } => {
1627                let mut typ = input.typ(outers, params);
1628                for scalar in scalars {
1629                    typ.column_types.push(scalar.typ(outers, &typ, params));
1630                }
1631                typ
1632            }
1633            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1634            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1635                input.typ(outers, params)
1636            }
1637            HirRelationExpr::Join {
1638                left, right, kind, ..
1639            } => {
1640                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1641                let right_nullable =
1642                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1643                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1644                    let nullable = t.nullable || left_nullable;
1645                    t.nullable(nullable)
1646                });
1647                let mut outers = outers.to_vec();
1648                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1649                let rt = right
1650                    .typ(&outers, params)
1651                    .column_types
1652                    .into_iter()
1653                    .map(|t| {
1654                        let nullable = t.nullable || right_nullable;
1655                        t.nullable(nullable)
1656                    });
1657                SqlRelationType::new(lt.chain(rt).collect())
1658            }
1659            HirRelationExpr::Reduce {
1660                input,
1661                group_key,
1662                aggregates,
1663                expected_group_size: _,
1664            } => {
1665                let input_typ = input.typ(outers, params);
1666                let mut column_types = group_key
1667                    .iter()
1668                    .map(|&i| input_typ.column_types[i].clone())
1669                    .collect::<Vec<_>>();
1670                for agg in aggregates {
1671                    column_types.push(agg.typ(outers, &input_typ, params));
1672                }
1673                // TODO(frank): add primary key information.
1674                SqlRelationType::new(column_types)
1675            }
1676            // TODO(frank): check for removal; add primary key information.
1677            HirRelationExpr::Distinct { input }
1678            | HirRelationExpr::Negate { input }
1679            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1680            HirRelationExpr::Union { base, inputs } => {
1681                let mut base_cols = base.typ(outers, params).column_types;
1682                for input in inputs {
1683                    for (base_col, col) in base_cols
1684                        .iter_mut()
1685                        .zip_eq(input.typ(outers, params).column_types)
1686                    {
1687                        *base_col = base_col.sql_union(&col).unwrap(); // HIR deliberately not using `union`
1688                    }
1689                }
1690                SqlRelationType::new(base_cols)
1691            }
1692        })
1693    }
1694
1695    pub fn arity(&self) -> usize {
1696        match self {
1697            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1698            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1699            HirRelationExpr::Let { body, .. } => body.arity(),
1700            HirRelationExpr::LetRec { body, .. } => body.arity(),
1701            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1702            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1703            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1704            HirRelationExpr::Filter { input, .. }
1705            | HirRelationExpr::TopK { input, .. }
1706            | HirRelationExpr::Distinct { input }
1707            | HirRelationExpr::Negate { input }
1708            | HirRelationExpr::Threshold { input } => input.arity(),
1709            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1710            HirRelationExpr::Union { base, .. } => base.arity(),
1711            HirRelationExpr::Reduce {
1712                group_key,
1713                aggregates,
1714                ..
1715            } => group_key.len() + aggregates.len(),
1716        }
1717    }
1718
1719    /// If self is a constant, return the value and the type, otherwise `None`.
1720    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1721        match self {
1722            Self::Constant { rows, typ } => Some((rows, typ)),
1723            _ => None,
1724        }
1725    }
1726
1727    /// Reports whether this expression contains a column reference to its
1728    /// direct parent scope.
1729    pub fn is_correlated(&self) -> bool {
1730        let mut correlated = false;
1731        #[allow(deprecated)]
1732        self.visit_columns(0, &mut |depth, col| {
1733            if col.level > depth && col.level - depth == 1 {
1734                correlated = true;
1735            }
1736        });
1737        correlated
1738    }
1739
1740    pub fn is_join_identity(&self) -> bool {
1741        match self {
1742            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1743            _ => false,
1744        }
1745    }
1746
1747    pub fn project(self, outputs: Vec<usize>) -> Self {
1748        if outputs.iter().copied().eq(0..self.arity()) {
1749            // The projection is trivial. Suppress it.
1750            self
1751        } else {
1752            HirRelationExpr::Project {
1753                input: Box::new(self),
1754                outputs,
1755            }
1756        }
1757    }
1758
1759    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1760        if scalars.is_empty() {
1761            // The map is trivial. Suppress it.
1762            self
1763        } else if let HirRelationExpr::Map {
1764            scalars: old_scalars,
1765            input: _,
1766        } = &mut self
1767        {
1768            // Map applied to a map. Fuse the maps.
1769            old_scalars.extend(scalars);
1770            self
1771        } else {
1772            HirRelationExpr::Map {
1773                input: Box::new(self),
1774                scalars,
1775            }
1776        }
1777    }
1778
1779    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1780        if let HirRelationExpr::Filter {
1781            input: _,
1782            predicates,
1783        } = &mut self
1784        {
1785            predicates.extend(preds);
1786            predicates.sort();
1787            predicates.dedup();
1788            self
1789        } else {
1790            preds.sort();
1791            preds.dedup();
1792            HirRelationExpr::Filter {
1793                input: Box::new(self),
1794                predicates: preds,
1795            }
1796        }
1797    }
1798
1799    pub fn reduce(
1800        self,
1801        group_key: Vec<usize>,
1802        aggregates: Vec<AggregateExpr>,
1803        expected_group_size: Option<u64>,
1804    ) -> Self {
1805        HirRelationExpr::Reduce {
1806            input: Box::new(self),
1807            group_key,
1808            aggregates,
1809            expected_group_size,
1810        }
1811    }
1812
1813    pub fn top_k(
1814        self,
1815        group_key: Vec<usize>,
1816        order_key: Vec<ColumnOrder>,
1817        limit: Option<HirScalarExpr>,
1818        offset: HirScalarExpr,
1819        expected_group_size: Option<u64>,
1820    ) -> Self {
1821        HirRelationExpr::TopK {
1822            input: Box::new(self),
1823            group_key,
1824            order_key,
1825            limit,
1826            offset,
1827            expected_group_size,
1828        }
1829    }
1830
1831    pub fn negate(self) -> Self {
1832        if let HirRelationExpr::Negate { input } = self {
1833            *input
1834        } else {
1835            HirRelationExpr::Negate {
1836                input: Box::new(self),
1837            }
1838        }
1839    }
1840
1841    pub fn distinct(self) -> Self {
1842        if let HirRelationExpr::Distinct { .. } = self {
1843            self
1844        } else {
1845            HirRelationExpr::Distinct {
1846                input: Box::new(self),
1847            }
1848        }
1849    }
1850
1851    pub fn threshold(self) -> Self {
1852        if let HirRelationExpr::Threshold { .. } = self {
1853            self
1854        } else {
1855            HirRelationExpr::Threshold {
1856                input: Box::new(self),
1857            }
1858        }
1859    }
1860
1861    pub fn union(self, other: Self) -> Self {
1862        let mut terms = Vec::new();
1863        if let HirRelationExpr::Union { base, inputs } = self {
1864            terms.push(*base);
1865            terms.extend(inputs);
1866        } else {
1867            terms.push(self);
1868        }
1869        if let HirRelationExpr::Union { base, inputs } = other {
1870            terms.push(*base);
1871            terms.extend(inputs);
1872        } else {
1873            terms.push(other);
1874        }
1875        HirRelationExpr::Union {
1876            base: Box::new(terms.remove(0)),
1877            inputs: terms,
1878        }
1879    }
1880
1881    pub fn exists(self) -> HirScalarExpr {
1882        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1883    }
1884
1885    pub fn select(self) -> HirScalarExpr {
1886        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1887    }
1888
1889    pub fn join(
1890        self,
1891        mut right: HirRelationExpr,
1892        on: HirScalarExpr,
1893        kind: JoinKind,
1894    ) -> HirRelationExpr {
1895        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1896        {
1897            // The join can be elided, but we need to adjust column references
1898            // on the right-hand side to account for the removal of the scope
1899            // introduced by the join.
1900            #[allow(deprecated)]
1901            right.visit_columns_mut(0, &mut |depth, col| {
1902                if col.level > depth {
1903                    col.level -= 1;
1904                }
1905            });
1906            right
1907        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1908            self
1909        } else {
1910            HirRelationExpr::Join {
1911                left: Box::new(self),
1912                right: Box::new(right),
1913                on,
1914                kind,
1915            }
1916        }
1917    }
1918
1919    pub fn take(&mut self) -> HirRelationExpr {
1920        mem::replace(
1921            self,
1922            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1923        )
1924    }
1925
1926    #[deprecated = "Use `Visit::visit_post`."]
1927    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1928    where
1929        F: FnMut(&'a Self, usize),
1930    {
1931        #[allow(deprecated)]
1932        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1933                                                 depth: usize|
1934         -> Result<(), ()> {
1935            f(e, depth);
1936            Ok(())
1937        });
1938    }
1939
1940    #[deprecated = "Use `Visit::try_visit_post`."]
1941    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1942    where
1943        F: FnMut(&'a Self, usize) -> Result<(), E>,
1944    {
1945        #[allow(deprecated)]
1946        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1947            e.visit_fallible(depth, f)
1948        })?;
1949        f(self, depth)
1950    }
1951
1952    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1953    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1954    where
1955        F: FnMut(&'a Self, usize) -> Result<(), E>,
1956    {
1957        match self {
1958            HirRelationExpr::Constant { .. }
1959            | HirRelationExpr::Get { .. }
1960            | HirRelationExpr::CallTable { .. } => (),
1961            HirRelationExpr::Let { body, value, .. } => {
1962                f(value, depth)?;
1963                f(body, depth)?;
1964            }
1965            HirRelationExpr::LetRec {
1966                limit: _,
1967                bindings,
1968                body,
1969            } => {
1970                for (_, _, value, _) in bindings.iter() {
1971                    f(value, depth)?;
1972                }
1973                f(body, depth)?;
1974            }
1975            HirRelationExpr::Project { input, .. } => {
1976                f(input, depth)?;
1977            }
1978            HirRelationExpr::Map { input, .. } => {
1979                f(input, depth)?;
1980            }
1981            HirRelationExpr::Filter { input, .. } => {
1982                f(input, depth)?;
1983            }
1984            HirRelationExpr::Join { left, right, .. } => {
1985                f(left, depth)?;
1986                f(right, depth + 1)?;
1987            }
1988            HirRelationExpr::Reduce { input, .. } => {
1989                f(input, depth)?;
1990            }
1991            HirRelationExpr::Distinct { input } => {
1992                f(input, depth)?;
1993            }
1994            HirRelationExpr::TopK { input, .. } => {
1995                f(input, depth)?;
1996            }
1997            HirRelationExpr::Negate { input } => {
1998                f(input, depth)?;
1999            }
2000            HirRelationExpr::Threshold { input } => {
2001                f(input, depth)?;
2002            }
2003            HirRelationExpr::Union { base, inputs } => {
2004                f(base, depth)?;
2005                for input in inputs {
2006                    f(input, depth)?;
2007                }
2008            }
2009        }
2010        Ok(())
2011    }
2012
2013    #[deprecated = "Use `Visit::visit_mut_post` instead."]
2014    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2015    where
2016        F: FnMut(&mut Self, usize),
2017    {
2018        #[allow(deprecated)]
2019        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2020                                                     depth: usize|
2021         -> Result<(), ()> {
2022            f(e, depth);
2023            Ok(())
2024        });
2025    }
2026
2027    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2028    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2029    where
2030        F: FnMut(&mut Self, usize) -> Result<(), E>,
2031    {
2032        #[allow(deprecated)]
2033        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2034            e.visit_mut_fallible(depth, f)
2035        })?;
2036        f(self, depth)
2037    }
2038
2039    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2040    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2041    where
2042        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2043    {
2044        match self {
2045            HirRelationExpr::Constant { .. }
2046            | HirRelationExpr::Get { .. }
2047            | HirRelationExpr::CallTable { .. } => (),
2048            HirRelationExpr::Let { body, value, .. } => {
2049                f(value, depth)?;
2050                f(body, depth)?;
2051            }
2052            HirRelationExpr::LetRec {
2053                limit: _,
2054                bindings,
2055                body,
2056            } => {
2057                for (_, _, value, _) in bindings.iter_mut() {
2058                    f(value, depth)?;
2059                }
2060                f(body, depth)?;
2061            }
2062            HirRelationExpr::Project { input, .. } => {
2063                f(input, depth)?;
2064            }
2065            HirRelationExpr::Map { input, .. } => {
2066                f(input, depth)?;
2067            }
2068            HirRelationExpr::Filter { input, .. } => {
2069                f(input, depth)?;
2070            }
2071            HirRelationExpr::Join { left, right, .. } => {
2072                f(left, depth)?;
2073                f(right, depth + 1)?;
2074            }
2075            HirRelationExpr::Reduce { input, .. } => {
2076                f(input, depth)?;
2077            }
2078            HirRelationExpr::Distinct { input } => {
2079                f(input, depth)?;
2080            }
2081            HirRelationExpr::TopK { input, .. } => {
2082                f(input, depth)?;
2083            }
2084            HirRelationExpr::Negate { input } => {
2085                f(input, depth)?;
2086            }
2087            HirRelationExpr::Threshold { input } => {
2088                f(input, depth)?;
2089            }
2090            HirRelationExpr::Union { base, inputs } => {
2091                f(base, depth)?;
2092                for input in inputs {
2093                    f(input, depth)?;
2094                }
2095            }
2096        }
2097        Ok(())
2098    }
2099
2100    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2101    /// Visits all scalar expressions within the sub-tree of the given relation.
2102    ///
2103    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2104    /// which will be incremented when entering the RHS of a join or a subquery and
2105    /// presented to the supplied function `f`.
2106    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2107    where
2108        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2109    {
2110        #[allow(deprecated)]
2111        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2112                                         depth: usize|
2113         -> Result<(), E> {
2114            match e {
2115                HirRelationExpr::Join { on, .. } => {
2116                    f(on, depth)?;
2117                }
2118                HirRelationExpr::Map { scalars, .. } => {
2119                    for scalar in scalars {
2120                        f(scalar, depth)?;
2121                    }
2122                }
2123                HirRelationExpr::CallTable { exprs, .. } => {
2124                    for expr in exprs {
2125                        f(expr, depth)?;
2126                    }
2127                }
2128                HirRelationExpr::Filter { predicates, .. } => {
2129                    for predicate in predicates {
2130                        f(predicate, depth)?;
2131                    }
2132                }
2133                HirRelationExpr::Reduce { aggregates, .. } => {
2134                    for aggregate in aggregates {
2135                        f(&aggregate.expr, depth)?;
2136                    }
2137                }
2138                HirRelationExpr::TopK { limit, offset, .. } => {
2139                    if let Some(limit) = limit {
2140                        f(limit, depth)?;
2141                    }
2142                    f(offset, depth)?;
2143                }
2144                HirRelationExpr::Union { .. }
2145                | HirRelationExpr::Let { .. }
2146                | HirRelationExpr::LetRec { .. }
2147                | HirRelationExpr::Project { .. }
2148                | HirRelationExpr::Distinct { .. }
2149                | HirRelationExpr::Negate { .. }
2150                | HirRelationExpr::Threshold { .. }
2151                | HirRelationExpr::Constant { .. }
2152                | HirRelationExpr::Get { .. } => (),
2153            }
2154            Ok(())
2155        })
2156    }
2157
2158    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2159    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2160    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2161    where
2162        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2163    {
2164        #[allow(deprecated)]
2165        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2166                                             depth: usize|
2167         -> Result<(), E> {
2168            match e {
2169                HirRelationExpr::Join { on, .. } => {
2170                    f(on, depth)?;
2171                }
2172                HirRelationExpr::Map { scalars, .. } => {
2173                    for scalar in scalars.iter_mut() {
2174                        f(scalar, depth)?;
2175                    }
2176                }
2177                HirRelationExpr::CallTable { exprs, .. } => {
2178                    for expr in exprs.iter_mut() {
2179                        f(expr, depth)?;
2180                    }
2181                }
2182                HirRelationExpr::Filter { predicates, .. } => {
2183                    for predicate in predicates.iter_mut() {
2184                        f(predicate, depth)?;
2185                    }
2186                }
2187                HirRelationExpr::Reduce { aggregates, .. } => {
2188                    for aggregate in aggregates.iter_mut() {
2189                        f(&mut aggregate.expr, depth)?;
2190                    }
2191                }
2192                HirRelationExpr::TopK { limit, offset, .. } => {
2193                    if let Some(limit) = limit {
2194                        f(limit, depth)?;
2195                    }
2196                    f(offset, depth)?;
2197                }
2198                HirRelationExpr::Union { .. }
2199                | HirRelationExpr::Let { .. }
2200                | HirRelationExpr::LetRec { .. }
2201                | HirRelationExpr::Project { .. }
2202                | HirRelationExpr::Distinct { .. }
2203                | HirRelationExpr::Negate { .. }
2204                | HirRelationExpr::Threshold { .. }
2205                | HirRelationExpr::Constant { .. }
2206                | HirRelationExpr::Get { .. } => (),
2207            }
2208            Ok(())
2209        })
2210    }
2211
2212    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2213    /// Visits the column references in this relation expression.
2214    ///
2215    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2216    /// which will be incremented when entering the RHS of a join or a subquery and
2217    /// presented to the supplied function `f`.
2218    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2219    where
2220        F: FnMut(usize, &ColumnRef),
2221    {
2222        #[allow(deprecated)]
2223        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2224                                                           depth: usize|
2225         -> Result<(), ()> {
2226            e.visit_columns(depth, f);
2227            Ok(())
2228        });
2229    }
2230
2231    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2232    /// Like `visit_columns`, but permits mutating the column references.
2233    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2234    where
2235        F: FnMut(usize, &mut ColumnRef),
2236    {
2237        #[allow(deprecated)]
2238        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2239                                                               depth: usize|
2240         -> Result<(), ()> {
2241            e.visit_columns_mut(depth, f);
2242            Ok(())
2243        });
2244    }
2245
2246    /// Replaces any parameter references in the expression with the
2247    /// corresponding datum from `params`.
2248    pub fn bind_parameters(
2249        &mut self,
2250        scx: &StatementContext,
2251        lifetime: QueryLifetime,
2252        params: &Params,
2253    ) -> Result<(), PlanError> {
2254        #[allow(deprecated)]
2255        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2256            e.bind_parameters(scx, lifetime, params)
2257        })
2258    }
2259
2260    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2261        let mut contains_parameters = false;
2262        #[allow(deprecated)]
2263        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2264            if e.contains_parameters() {
2265                contains_parameters = true;
2266            }
2267            Ok::<(), PlanError>(())
2268        })?;
2269        Ok(contains_parameters)
2270    }
2271
2272    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2273    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2274        #[allow(deprecated)]
2275        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2276                                                               depth: usize|
2277         -> Result<(), ()> {
2278            e.splice_parameters(params, depth);
2279            Ok(())
2280        });
2281    }
2282
2283    /// Constructs a constant collection from specific rows and schema.
2284    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2285        let rows = rows
2286            .into_iter()
2287            .map(move |datums| Row::pack_slice(&datums))
2288            .collect();
2289        HirRelationExpr::Constant { rows, typ }
2290    }
2291
2292    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2293    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2294    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2295    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2296    /// finishing into a trivial finishing.
2297    pub fn finish_maintained(
2298        &mut self,
2299        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2300        group_size_hints: GroupSizeHints,
2301    ) {
2302        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2303            let old_finishing = mem::replace(
2304                finishing,
2305                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2306            );
2307            *self = HirRelationExpr::top_k(
2308                std::mem::replace(
2309                    self,
2310                    HirRelationExpr::Constant {
2311                        rows: vec![],
2312                        typ: SqlRelationType::new(Vec::new()),
2313                    },
2314                ),
2315                vec![],
2316                old_finishing.order_by,
2317                old_finishing.limit,
2318                old_finishing.offset,
2319                group_size_hints.limit_input_group_size,
2320            )
2321            .project(old_finishing.project);
2322        }
2323    }
2324
2325    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2326    ///
2327    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2328    /// parameter is not an HirScalarExpr anymore.)
2329    pub fn trivial_row_set_finishing_hir(
2330        arity: usize,
2331    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2332        RowSetFinishing {
2333            order_by: Vec::new(),
2334            limit: None,
2335            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2336            project: (0..arity).collect(),
2337        }
2338    }
2339
2340    /// True if the finishing does nothing to any result set.
2341    ///
2342    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2343    /// parameter is not an HirScalarExpr anymore.)
2344    pub fn is_trivial_row_set_finishing_hir(
2345        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2346        arity: usize,
2347    ) -> bool {
2348        rsf.limit.is_none()
2349            && rsf.order_by.is_empty()
2350            && rsf
2351                .offset
2352                .clone()
2353                .try_into_literal_int64()
2354                .is_ok_and(|o| o == 0)
2355            && rsf.project.iter().copied().eq(0..arity)
2356    }
2357
2358    /// The HirRelationExpr is considered potentially expensive if and only if
2359    /// at least one of the following conditions is true:
2360    ///
2361    ///  - It contains at least one CallTable or a Reduce operator.
2362    ///  - It contains at least one HirScalarExpr with a function call.
2363    ///
2364    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2365    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2366    pub fn could_run_expensive_function(&self) -> bool {
2367        let mut result = false;
2368        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2369            use HirRelationExpr::*;
2370            use HirScalarExpr::*;
2371
2372            self.visit_children(|scalar: &HirScalarExpr| {
2373                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2374                    result |= match scalar {
2375                        Column(..)
2376                        | Literal(..)
2377                        | CallUnmaterializable(..)
2378                        | If { .. }
2379                        | Parameter(..)
2380                        | Select(..)
2381                        | Exists(..) => false,
2382                        // Function calls are considered expensive
2383                        CallUnary { .. }
2384                        | CallBinary { .. }
2385                        | CallVariadic { .. }
2386                        | Windowing(..) => true,
2387                    };
2388                }) {
2389                    // Conservatively set `true` on RecursionLimitError.
2390                    result = true;
2391                }
2392            });
2393
2394            // CallTable has a table function; Reduce has an aggregate function.
2395            // Other constructs use MirScalarExpr to run a function
2396            result |= matches!(e, CallTable { .. } | Reduce { .. });
2397        }) {
2398            // Conservatively set `true` on RecursionLimitError.
2399            result = true;
2400        }
2401
2402        result
2403    }
2404
2405    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2406    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2407        let mut contains = false;
2408        self.visit_post(&mut |expr| {
2409            expr.visit_children(|expr: &HirScalarExpr| {
2410                contains = contains || expr.contains_temporal()
2411            })
2412        })?;
2413        Ok(contains)
2414    }
2415
2416    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2417    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2418        let mut contains = false;
2419        self.visit_post(&mut |expr| {
2420            expr.visit_children(|expr: &HirScalarExpr| {
2421                contains = contains || expr.contains_unmaterializable()
2422            })
2423        })?;
2424        Ok(contains)
2425    }
2426}
2427
2428impl CollectionPlan for HirRelationExpr {
2429    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2430    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2431    /// (It does explore inside subqueries.)
2432    ///
2433    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2434    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2435    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2436        if let Self::Get {
2437            id: Id::Global(id), ..
2438        } = self
2439        {
2440            out.insert(*id);
2441        }
2442        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2443    }
2444}
2445
2446impl VisitChildren<Self> for HirRelationExpr {
2447    fn visit_children<F>(&self, mut f: F)
2448    where
2449        F: FnMut(&Self),
2450    {
2451        // subqueries of type HirRelationExpr might be wrapped in
2452        // Exists or Select variants within HirScalarExpr trees
2453        // attached at the current node, and we want to visit them as well
2454        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2455            #[allow(deprecated)]
2456            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2457                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2458                    f(expr.as_ref())
2459                }
2460                _ => (),
2461            });
2462        });
2463
2464        use HirRelationExpr::*;
2465        match self {
2466            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2467            Let {
2468                name: _,
2469                id: _,
2470                value,
2471                body,
2472            } => {
2473                f(value);
2474                f(body);
2475            }
2476            LetRec {
2477                limit: _,
2478                bindings,
2479                body,
2480            } => {
2481                for (_, _, value, _) in bindings.iter() {
2482                    f(value);
2483                }
2484                f(body);
2485            }
2486            Project { input, outputs: _ } => f(input),
2487            Map { input, scalars: _ } => {
2488                f(input);
2489            }
2490            CallTable { func: _, exprs: _ } => (),
2491            Filter {
2492                input,
2493                predicates: _,
2494            } => {
2495                f(input);
2496            }
2497            Join {
2498                left,
2499                right,
2500                on: _,
2501                kind: _,
2502            } => {
2503                f(left);
2504                f(right);
2505            }
2506            Reduce {
2507                input,
2508                group_key: _,
2509                aggregates: _,
2510                expected_group_size: _,
2511            } => {
2512                f(input);
2513            }
2514            Distinct { input }
2515            | TopK {
2516                input,
2517                group_key: _,
2518                order_key: _,
2519                limit: _,
2520                offset: _,
2521                expected_group_size: _,
2522            }
2523            | Negate { input }
2524            | Threshold { input } => {
2525                f(input);
2526            }
2527            Union { base, inputs } => {
2528                f(base);
2529                for input in inputs {
2530                    f(input);
2531                }
2532            }
2533        }
2534    }
2535
2536    fn visit_mut_children<F>(&mut self, mut f: F)
2537    where
2538        F: FnMut(&mut Self),
2539    {
2540        // subqueries of type HirRelationExpr might be wrapped in
2541        // Exists or Select variants within HirScalarExpr trees
2542        // attached at the current node, and we want to visit them as well
2543        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2544            #[allow(deprecated)]
2545            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2546                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2547                    f(expr.as_mut())
2548                }
2549                _ => (),
2550            });
2551        });
2552
2553        use HirRelationExpr::*;
2554        match self {
2555            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2556            Let {
2557                name: _,
2558                id: _,
2559                value,
2560                body,
2561            } => {
2562                f(value);
2563                f(body);
2564            }
2565            LetRec {
2566                limit: _,
2567                bindings,
2568                body,
2569            } => {
2570                for (_, _, value, _) in bindings.iter_mut() {
2571                    f(value);
2572                }
2573                f(body);
2574            }
2575            Project { input, outputs: _ } => f(input),
2576            Map { input, scalars: _ } => {
2577                f(input);
2578            }
2579            CallTable { func: _, exprs: _ } => (),
2580            Filter {
2581                input,
2582                predicates: _,
2583            } => {
2584                f(input);
2585            }
2586            Join {
2587                left,
2588                right,
2589                on: _,
2590                kind: _,
2591            } => {
2592                f(left);
2593                f(right);
2594            }
2595            Reduce {
2596                input,
2597                group_key: _,
2598                aggregates: _,
2599                expected_group_size: _,
2600            } => {
2601                f(input);
2602            }
2603            Distinct { input }
2604            | TopK {
2605                input,
2606                group_key: _,
2607                order_key: _,
2608                limit: _,
2609                offset: _,
2610                expected_group_size: _,
2611            }
2612            | Negate { input }
2613            | Threshold { input } => {
2614                f(input);
2615            }
2616            Union { base, inputs } => {
2617                f(base);
2618                for input in inputs {
2619                    f(input);
2620                }
2621            }
2622        }
2623    }
2624
2625    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2626    where
2627        F: FnMut(&Self) -> Result<(), E>,
2628        E: From<RecursionLimitError>,
2629    {
2630        // subqueries of type HirRelationExpr might be wrapped in
2631        // Exists or Select variants within HirScalarExpr trees
2632        // attached at the current node, and we want to visit them as well
2633        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2634            Visit::try_visit_post(expr, &mut |expr| match expr {
2635                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2636                    f(expr.as_ref())
2637                }
2638                _ => Ok(()),
2639            })
2640        })?;
2641
2642        use HirRelationExpr::*;
2643        match self {
2644            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2645            Let {
2646                name: _,
2647                id: _,
2648                value,
2649                body,
2650            } => {
2651                f(value)?;
2652                f(body)?;
2653            }
2654            LetRec {
2655                limit: _,
2656                bindings,
2657                body,
2658            } => {
2659                for (_, _, value, _) in bindings.iter() {
2660                    f(value)?;
2661                }
2662                f(body)?;
2663            }
2664            Project { input, outputs: _ } => f(input)?,
2665            Map { input, scalars: _ } => {
2666                f(input)?;
2667            }
2668            CallTable { func: _, exprs: _ } => (),
2669            Filter {
2670                input,
2671                predicates: _,
2672            } => {
2673                f(input)?;
2674            }
2675            Join {
2676                left,
2677                right,
2678                on: _,
2679                kind: _,
2680            } => {
2681                f(left)?;
2682                f(right)?;
2683            }
2684            Reduce {
2685                input,
2686                group_key: _,
2687                aggregates: _,
2688                expected_group_size: _,
2689            } => {
2690                f(input)?;
2691            }
2692            Distinct { input }
2693            | TopK {
2694                input,
2695                group_key: _,
2696                order_key: _,
2697                limit: _,
2698                offset: _,
2699                expected_group_size: _,
2700            }
2701            | Negate { input }
2702            | Threshold { input } => {
2703                f(input)?;
2704            }
2705            Union { base, inputs } => {
2706                f(base)?;
2707                for input in inputs {
2708                    f(input)?;
2709                }
2710            }
2711        }
2712        Ok(())
2713    }
2714
2715    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2716    where
2717        F: FnMut(&mut Self) -> Result<(), E>,
2718        E: From<RecursionLimitError>,
2719    {
2720        // subqueries of type HirRelationExpr might be wrapped in
2721        // Exists or Select variants within HirScalarExpr trees
2722        // attached at the current node, and we want to visit them as well
2723        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2724            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2725                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2726                    f(expr.as_mut())
2727                }
2728                _ => Ok(()),
2729            })
2730        })?;
2731
2732        use HirRelationExpr::*;
2733        match self {
2734            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2735            Let {
2736                name: _,
2737                id: _,
2738                value,
2739                body,
2740            } => {
2741                f(value)?;
2742                f(body)?;
2743            }
2744            LetRec {
2745                limit: _,
2746                bindings,
2747                body,
2748            } => {
2749                for (_, _, value, _) in bindings.iter_mut() {
2750                    f(value)?;
2751                }
2752                f(body)?;
2753            }
2754            Project { input, outputs: _ } => f(input)?,
2755            Map { input, scalars: _ } => {
2756                f(input)?;
2757            }
2758            CallTable { func: _, exprs: _ } => (),
2759            Filter {
2760                input,
2761                predicates: _,
2762            } => {
2763                f(input)?;
2764            }
2765            Join {
2766                left,
2767                right,
2768                on: _,
2769                kind: _,
2770            } => {
2771                f(left)?;
2772                f(right)?;
2773            }
2774            Reduce {
2775                input,
2776                group_key: _,
2777                aggregates: _,
2778                expected_group_size: _,
2779            } => {
2780                f(input)?;
2781            }
2782            Distinct { input }
2783            | TopK {
2784                input,
2785                group_key: _,
2786                order_key: _,
2787                limit: _,
2788                offset: _,
2789                expected_group_size: _,
2790            }
2791            | Negate { input }
2792            | Threshold { input } => {
2793                f(input)?;
2794            }
2795            Union { base, inputs } => {
2796                f(base)?;
2797                for input in inputs {
2798                    f(input)?;
2799                }
2800            }
2801        }
2802        Ok(())
2803    }
2804}
2805
2806impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2807    fn visit_children<F>(&self, mut f: F)
2808    where
2809        F: FnMut(&HirScalarExpr),
2810    {
2811        use HirRelationExpr::*;
2812        match self {
2813            Constant { rows: _, typ: _ }
2814            | Get { id: _, typ: _ }
2815            | Let {
2816                name: _,
2817                id: _,
2818                value: _,
2819                body: _,
2820            }
2821            | LetRec {
2822                limit: _,
2823                bindings: _,
2824                body: _,
2825            }
2826            | Project {
2827                input: _,
2828                outputs: _,
2829            } => (),
2830            Map { input: _, scalars } => {
2831                for scalar in scalars {
2832                    f(scalar);
2833                }
2834            }
2835            CallTable { func: _, exprs } => {
2836                for expr in exprs {
2837                    f(expr);
2838                }
2839            }
2840            Filter {
2841                input: _,
2842                predicates,
2843            } => {
2844                for predicate in predicates {
2845                    f(predicate);
2846                }
2847            }
2848            Join {
2849                left: _,
2850                right: _,
2851                on,
2852                kind: _,
2853            } => f(on),
2854            Reduce {
2855                input: _,
2856                group_key: _,
2857                aggregates,
2858                expected_group_size: _,
2859            } => {
2860                for aggregate in aggregates {
2861                    f(aggregate.expr.as_ref());
2862                }
2863            }
2864            TopK {
2865                input: _,
2866                group_key: _,
2867                order_key: _,
2868                limit,
2869                offset,
2870                expected_group_size: _,
2871            } => {
2872                if let Some(limit) = limit {
2873                    f(limit)
2874                }
2875                f(offset)
2876            }
2877            Distinct { input: _ }
2878            | Negate { input: _ }
2879            | Threshold { input: _ }
2880            | Union { base: _, inputs: _ } => (),
2881        }
2882    }
2883
2884    fn visit_mut_children<F>(&mut self, mut f: F)
2885    where
2886        F: FnMut(&mut HirScalarExpr),
2887    {
2888        use HirRelationExpr::*;
2889        match self {
2890            Constant { rows: _, typ: _ }
2891            | Get { id: _, typ: _ }
2892            | Let {
2893                name: _,
2894                id: _,
2895                value: _,
2896                body: _,
2897            }
2898            | LetRec {
2899                limit: _,
2900                bindings: _,
2901                body: _,
2902            }
2903            | Project {
2904                input: _,
2905                outputs: _,
2906            } => (),
2907            Map { input: _, scalars } => {
2908                for scalar in scalars {
2909                    f(scalar);
2910                }
2911            }
2912            CallTable { func: _, exprs } => {
2913                for expr in exprs {
2914                    f(expr);
2915                }
2916            }
2917            Filter {
2918                input: _,
2919                predicates,
2920            } => {
2921                for predicate in predicates {
2922                    f(predicate);
2923                }
2924            }
2925            Join {
2926                left: _,
2927                right: _,
2928                on,
2929                kind: _,
2930            } => f(on),
2931            Reduce {
2932                input: _,
2933                group_key: _,
2934                aggregates,
2935                expected_group_size: _,
2936            } => {
2937                for aggregate in aggregates {
2938                    f(aggregate.expr.as_mut());
2939                }
2940            }
2941            TopK {
2942                input: _,
2943                group_key: _,
2944                order_key: _,
2945                limit,
2946                offset,
2947                expected_group_size: _,
2948            } => {
2949                if let Some(limit) = limit {
2950                    f(limit)
2951                }
2952                f(offset)
2953            }
2954            Distinct { input: _ }
2955            | Negate { input: _ }
2956            | Threshold { input: _ }
2957            | Union { base: _, inputs: _ } => (),
2958        }
2959    }
2960
2961    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2962    where
2963        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2964        E: From<RecursionLimitError>,
2965    {
2966        use HirRelationExpr::*;
2967        match self {
2968            Constant { rows: _, typ: _ }
2969            | Get { id: _, typ: _ }
2970            | Let {
2971                name: _,
2972                id: _,
2973                value: _,
2974                body: _,
2975            }
2976            | LetRec {
2977                limit: _,
2978                bindings: _,
2979                body: _,
2980            }
2981            | Project {
2982                input: _,
2983                outputs: _,
2984            } => (),
2985            Map { input: _, scalars } => {
2986                for scalar in scalars {
2987                    f(scalar)?;
2988                }
2989            }
2990            CallTable { func: _, exprs } => {
2991                for expr in exprs {
2992                    f(expr)?;
2993                }
2994            }
2995            Filter {
2996                input: _,
2997                predicates,
2998            } => {
2999                for predicate in predicates {
3000                    f(predicate)?;
3001                }
3002            }
3003            Join {
3004                left: _,
3005                right: _,
3006                on,
3007                kind: _,
3008            } => f(on)?,
3009            Reduce {
3010                input: _,
3011                group_key: _,
3012                aggregates,
3013                expected_group_size: _,
3014            } => {
3015                for aggregate in aggregates {
3016                    f(aggregate.expr.as_ref())?;
3017                }
3018            }
3019            TopK {
3020                input: _,
3021                group_key: _,
3022                order_key: _,
3023                limit,
3024                offset,
3025                expected_group_size: _,
3026            } => {
3027                if let Some(limit) = limit {
3028                    f(limit)?
3029                }
3030                f(offset)?
3031            }
3032            Distinct { input: _ }
3033            | Negate { input: _ }
3034            | Threshold { input: _ }
3035            | Union { base: _, inputs: _ } => (),
3036        }
3037        Ok(())
3038    }
3039
3040    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3041    where
3042        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3043        E: From<RecursionLimitError>,
3044    {
3045        use HirRelationExpr::*;
3046        match self {
3047            Constant { rows: _, typ: _ }
3048            | Get { id: _, typ: _ }
3049            | Let {
3050                name: _,
3051                id: _,
3052                value: _,
3053                body: _,
3054            }
3055            | LetRec {
3056                limit: _,
3057                bindings: _,
3058                body: _,
3059            }
3060            | Project {
3061                input: _,
3062                outputs: _,
3063            } => (),
3064            Map { input: _, scalars } => {
3065                for scalar in scalars {
3066                    f(scalar)?;
3067                }
3068            }
3069            CallTable { func: _, exprs } => {
3070                for expr in exprs {
3071                    f(expr)?;
3072                }
3073            }
3074            Filter {
3075                input: _,
3076                predicates,
3077            } => {
3078                for predicate in predicates {
3079                    f(predicate)?;
3080                }
3081            }
3082            Join {
3083                left: _,
3084                right: _,
3085                on,
3086                kind: _,
3087            } => f(on)?,
3088            Reduce {
3089                input: _,
3090                group_key: _,
3091                aggregates,
3092                expected_group_size: _,
3093            } => {
3094                for aggregate in aggregates {
3095                    f(aggregate.expr.as_mut())?;
3096                }
3097            }
3098            TopK {
3099                input: _,
3100                group_key: _,
3101                order_key: _,
3102                limit,
3103                offset,
3104                expected_group_size: _,
3105            } => {
3106                if let Some(limit) = limit {
3107                    f(limit)?
3108                }
3109                f(offset)?
3110            }
3111            Distinct { input: _ }
3112            | Negate { input: _ }
3113            | Threshold { input: _ }
3114            | Union { base: _, inputs: _ } => (),
3115        }
3116        Ok(())
3117    }
3118}
3119
3120impl HirScalarExpr {
3121    pub fn name(&self) -> Option<Arc<str>> {
3122        use HirScalarExpr::*;
3123        match self {
3124            Column(_, name)
3125            | Parameter(_, name)
3126            | Literal(_, _, name)
3127            | CallUnmaterializable(_, name)
3128            | CallUnary { name, .. }
3129            | CallBinary { name, .. }
3130            | CallVariadic { name, .. }
3131            | If { name, .. }
3132            | Exists(_, name)
3133            | Select(_, name)
3134            | Windowing(_, name) => name.0.clone(),
3135        }
3136    }
3137
3138    /// Replaces any parameter references in the expression with the
3139    /// corresponding datum in `params`.
3140    pub fn bind_parameters(
3141        &mut self,
3142        scx: &StatementContext,
3143        lifetime: QueryLifetime,
3144        params: &Params,
3145    ) -> Result<(), PlanError> {
3146        #[allow(deprecated)]
3147        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3148            if let HirScalarExpr::Parameter(n, name) = e {
3149                let datum = match params.datums.iter().nth(*n - 1) {
3150                    None => return Err(PlanError::UnknownParameter(*n)),
3151                    Some(datum) => datum,
3152                };
3153                let scalar_type = &params.execute_types[*n - 1];
3154                let row = Row::pack([datum]);
3155                let column_type = scalar_type.clone().nullable(datum.is_null());
3156
3157                let name = if let Some(name) = &name.0 {
3158                    Some(Arc::clone(name))
3159                } else {
3160                    Some(Arc::from(format!("${n}")))
3161                };
3162
3163                let qcx = QueryContext::root(scx, lifetime);
3164                let ecx = execute_expr_context(&qcx);
3165
3166                *e = plan_cast(
3167                    &ecx,
3168                    *EXECUTE_CAST_CONTEXT,
3169                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3170                    &params.expected_types[*n - 1],
3171                )
3172                .expect("checked in plan_params");
3173            }
3174            Ok(())
3175        })
3176    }
3177
3178    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3179    /// replaced with the corresponding expression fragment from `params` rather
3180    /// than a datum.
3181    ///
3182    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3183    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3184    /// in `self` that refer to invalid indices of `params` will cause a panic.
3185    ///
3186    /// Column references in parameters will be corrected to account for the
3187    /// depth at which they are spliced.
3188    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3189        #[allow(deprecated)]
3190        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3191                                                        e: &mut HirScalarExpr|
3192         -> Result<(), ()> {
3193            if let HirScalarExpr::Parameter(i, _name) = e {
3194                *e = params[*i - 1].clone();
3195                // Correct any column references in the parameter expression for
3196                // its new depth.
3197                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3198                    if col.level >= d {
3199                        col.level += depth
3200                    }
3201                });
3202            }
3203            Ok(())
3204        });
3205    }
3206
3207    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3208    pub fn contains_temporal(&self) -> bool {
3209        let mut contains = false;
3210        #[allow(deprecated)]
3211        self.visit_post_nolimit(&mut |e| {
3212            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3213                contains = true;
3214            }
3215        });
3216        contains
3217    }
3218
3219    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3220    pub fn contains_unmaterializable(&self) -> bool {
3221        let mut contains = false;
3222        #[allow(deprecated)]
3223        self.visit_post_nolimit(&mut |e| {
3224            if let Self::CallUnmaterializable(_, _) = e {
3225                contains = true;
3226            }
3227        });
3228        contains
3229    }
3230
3231    /// Constructs an unnamed column reference in the current scope.
3232    /// Use [`HirScalarExpr::named_column`] when a name is known.
3233    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3234    pub fn column(index: usize) -> HirScalarExpr {
3235        HirScalarExpr::Column(
3236            ColumnRef {
3237                level: 0,
3238                column: index,
3239            },
3240            TreatAsEqual(None),
3241        )
3242    }
3243
3244    /// Constructs an unnamed column reference.
3245    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3246        HirScalarExpr::Column(cr, TreatAsEqual(None))
3247    }
3248
3249    /// Constructs a named column reference.
3250    /// Names are interned by a `NameManager`.
3251    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3252        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3253    }
3254
3255    pub fn parameter(n: usize) -> HirScalarExpr {
3256        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3257    }
3258
3259    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3260        let col_type = scalar_type.nullable(datum.is_null());
3261        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3262        let row = Row::pack([datum]);
3263        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3264    }
3265
3266    pub fn literal_true() -> HirScalarExpr {
3267        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3268    }
3269
3270    pub fn literal_false() -> HirScalarExpr {
3271        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3272    }
3273
3274    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3275        HirScalarExpr::literal(Datum::Null, scalar_type)
3276    }
3277
3278    pub fn literal_1d_array(
3279        datums: Vec<Datum>,
3280        element_scalar_type: SqlScalarType,
3281    ) -> Result<HirScalarExpr, PlanError> {
3282        let scalar_type = match element_scalar_type {
3283            SqlScalarType::Array(_) => {
3284                sql_bail!("cannot build array from array type");
3285            }
3286            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3287        };
3288
3289        let mut row = Row::default();
3290        row.packer()
3291            .try_push_array(
3292                &[ArrayDimension {
3293                    lower_bound: 1,
3294                    length: datums.len(),
3295                }],
3296                datums,
3297            )
3298            .expect("array constructed to be valid");
3299
3300        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3301    }
3302
3303    pub fn as_literal(&self) -> Option<Datum<'_>> {
3304        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3305            Some(row.unpack_first())
3306        } else {
3307            None
3308        }
3309    }
3310
3311    pub fn is_literal_true(&self) -> bool {
3312        Some(Datum::True) == self.as_literal()
3313    }
3314
3315    pub fn is_literal_false(&self) -> bool {
3316        Some(Datum::False) == self.as_literal()
3317    }
3318
3319    pub fn is_literal_null(&self) -> bool {
3320        Some(Datum::Null) == self.as_literal()
3321    }
3322
3323    /// Return true iff `self` consists only of literals, materializable function calls, and
3324    /// if-else statements.
3325    pub fn is_constant(&self) -> bool {
3326        let mut worklist = vec![self];
3327        while let Some(expr) = worklist.pop() {
3328            match expr {
3329                Self::Literal(..) => {
3330                    // leaf node, do nothing
3331                }
3332                Self::CallUnary { expr, .. } => {
3333                    worklist.push(expr);
3334                }
3335                Self::CallBinary {
3336                    func: _,
3337                    expr1,
3338                    expr2,
3339                    name: _,
3340                } => {
3341                    worklist.push(expr1);
3342                    worklist.push(expr2);
3343                }
3344                Self::CallVariadic {
3345                    func: _,
3346                    exprs,
3347                    name: _,
3348                } => {
3349                    worklist.extend(exprs.iter());
3350                }
3351                // (CallUnmaterializable is not allowed)
3352                Self::If {
3353                    cond,
3354                    then,
3355                    els,
3356                    name: _,
3357                } => {
3358                    worklist.push(cond);
3359                    worklist.push(then);
3360                    worklist.push(els);
3361                }
3362                _ => {
3363                    return false; // Any other node makes `self` non-constant.
3364                }
3365            }
3366        }
3367        true
3368    }
3369
3370    pub fn call_unary(self, func: UnaryFunc) -> Self {
3371        HirScalarExpr::CallUnary {
3372            func,
3373            expr: Box::new(self),
3374            name: NameMetadata::default(),
3375        }
3376    }
3377
3378    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3379        HirScalarExpr::CallBinary {
3380            func: func.into(),
3381            expr1: Box::new(self),
3382            expr2: Box::new(other),
3383            name: NameMetadata::default(),
3384        }
3385    }
3386
3387    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3388        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3389    }
3390
3391    pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3392        HirScalarExpr::CallVariadic {
3393            func,
3394            exprs,
3395            name: NameMetadata::default(),
3396        }
3397    }
3398
3399    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3400        HirScalarExpr::If {
3401            cond: Box::new(cond),
3402            then: Box::new(then),
3403            els: Box::new(els),
3404            name: NameMetadata::default(),
3405        }
3406    }
3407
3408    pub fn windowing(expr: WindowExpr) -> Self {
3409        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3410    }
3411
3412    pub fn or(self, other: Self) -> Self {
3413        HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3414    }
3415
3416    pub fn and(self, other: Self) -> Self {
3417        HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3418    }
3419
3420    pub fn not(self) -> Self {
3421        self.call_unary(UnaryFunc::Not(func::Not))
3422    }
3423
3424    pub fn call_is_null(self) -> Self {
3425        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3426    }
3427
3428    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3429    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3430        match args.len() {
3431            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3432            1 => args.swap_remove(0),
3433            _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3434        }
3435    }
3436
3437    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3438    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3439        match args.len() {
3440            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3441            1 => args.swap_remove(0),
3442            _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3443        }
3444    }
3445
3446    pub fn take(&mut self) -> Self {
3447        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3448    }
3449
3450    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3451    /// Visits the column references in this scalar expression.
3452    ///
3453    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3454    /// which will be incremented with each subquery entered and presented to the supplied
3455    /// function `f`.
3456    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3457    where
3458        F: FnMut(usize, &ColumnRef),
3459    {
3460        #[allow(deprecated)]
3461        let _ = self.visit_recursively(depth, &mut |depth: usize,
3462                                                    e: &HirScalarExpr|
3463         -> Result<(), ()> {
3464            if let HirScalarExpr::Column(col, _name) = e {
3465                f(depth, col)
3466            }
3467            Ok(())
3468        });
3469    }
3470
3471    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3472    /// Like `visit_columns`, but permits mutating the column references.
3473    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3474    where
3475        F: FnMut(usize, &mut ColumnRef),
3476    {
3477        #[allow(deprecated)]
3478        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3479                                                        e: &mut HirScalarExpr|
3480         -> Result<(), ()> {
3481            if let HirScalarExpr::Column(col, _name) = e {
3482                f(depth, col)
3483            }
3484            Ok(())
3485        });
3486    }
3487
3488    /// Visits those column references in this scalar expression that refer to the root
3489    /// level. These include column references that are at the root level, as well as column
3490    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3491    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3492    /// "root level" to be `self`'s level.)
3493    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3494    where
3495        F: FnMut(usize),
3496    {
3497        #[allow(deprecated)]
3498        let _ = self.visit_recursively(0, &mut |depth: usize,
3499                                                e: &HirScalarExpr|
3500         -> Result<(), ()> {
3501            if let HirScalarExpr::Column(col, _name) = e {
3502                if col.level == depth {
3503                    f(col.column)
3504                }
3505            }
3506            Ok(())
3507        });
3508    }
3509
3510    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3511    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3512    where
3513        F: FnMut(&mut usize),
3514    {
3515        #[allow(deprecated)]
3516        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3517                                                    e: &mut HirScalarExpr|
3518         -> Result<(), ()> {
3519            if let HirScalarExpr::Column(col, _name) = e {
3520                if col.level == depth {
3521                    f(&mut col.column)
3522                }
3523            }
3524            Ok(())
3525        });
3526    }
3527
3528    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3529    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3530    /// in them. It takes the current depth of the expression and increases it when
3531    /// entering a subquery.
3532    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3533    where
3534        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3535    {
3536        match self {
3537            HirScalarExpr::Literal(..)
3538            | HirScalarExpr::Parameter(..)
3539            | HirScalarExpr::CallUnmaterializable(..)
3540            | HirScalarExpr::Column(..) => (),
3541            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3542            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3543                expr1.visit_recursively(depth, f)?;
3544                expr2.visit_recursively(depth, f)?;
3545            }
3546            HirScalarExpr::CallVariadic { exprs, .. } => {
3547                for expr in exprs {
3548                    expr.visit_recursively(depth, f)?;
3549                }
3550            }
3551            HirScalarExpr::If {
3552                cond,
3553                then,
3554                els,
3555                name: _,
3556            } => {
3557                cond.visit_recursively(depth, f)?;
3558                then.visit_recursively(depth, f)?;
3559                els.visit_recursively(depth, f)?;
3560            }
3561            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3562                #[allow(deprecated)]
3563                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3564                    e.visit_recursively(depth, f)
3565                })?;
3566            }
3567            HirScalarExpr::Windowing(expr, _name) => {
3568                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3569            }
3570        }
3571        f(depth, self)
3572    }
3573
3574    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3575    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3576    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3577    where
3578        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3579    {
3580        match self {
3581            HirScalarExpr::Literal(..)
3582            | HirScalarExpr::Parameter(..)
3583            | HirScalarExpr::CallUnmaterializable(..)
3584            | HirScalarExpr::Column(..) => (),
3585            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3586            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3587                expr1.visit_recursively_mut(depth, f)?;
3588                expr2.visit_recursively_mut(depth, f)?;
3589            }
3590            HirScalarExpr::CallVariadic { exprs, .. } => {
3591                for expr in exprs {
3592                    expr.visit_recursively_mut(depth, f)?;
3593                }
3594            }
3595            HirScalarExpr::If {
3596                cond,
3597                then,
3598                els,
3599                name: _,
3600            } => {
3601                cond.visit_recursively_mut(depth, f)?;
3602                then.visit_recursively_mut(depth, f)?;
3603                els.visit_recursively_mut(depth, f)?;
3604            }
3605            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3606                #[allow(deprecated)]
3607                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3608                    e.visit_recursively_mut(depth, f)
3609                })?;
3610            }
3611            HirScalarExpr::Windowing(expr, _name) => {
3612                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3613            }
3614        }
3615        f(depth, self)
3616    }
3617
3618    /// Attempts to simplify self into a literal.
3619    ///
3620    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3621    /// an evaluation error occurs during simplification, or if self contains
3622    /// - a subquery
3623    /// - a column reference to an outer level
3624    /// - a parameter
3625    /// - a window function call
3626    fn simplify_to_literal(self) -> Option<Row> {
3627        let mut expr = self
3628            .lower_uncorrelated(crate::plan::lowering::Config::default())
3629            .ok()?;
3630        expr.reduce(&[]);
3631        match expr {
3632            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3633            _ => None,
3634        }
3635    }
3636
3637    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3638    /// or an evaluation error occurs during simplification), it returns
3639    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3640    ///
3641    /// The returned error is an _internal_ error if the expression contains
3642    /// - a subquery
3643    /// - a column reference to an outer level
3644    /// - a parameter
3645    /// - a window function call
3646    ///
3647    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3648    /// msg.
3649    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3650        let mut expr = self
3651            .lower_uncorrelated(crate::plan::lowering::Config::default())
3652            .map_err(|err| {
3653                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3654            })?;
3655        expr.reduce(&[]);
3656        match expr {
3657            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3658            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3659                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3660            ),
3661            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3662                "Not a constant".to_string(),
3663            )),
3664        }
3665    }
3666
3667    /// Attempts to simplify this expression to a literal 64-bit integer.
3668    ///
3669    /// Returns `None` if this expression cannot be simplified, e.g. because it
3670    /// contains non-literal values.
3671    ///
3672    /// # Panics
3673    ///
3674    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3675    pub fn into_literal_int64(self) -> Option<i64> {
3676        self.simplify_to_literal().and_then(|row| {
3677            let datum = row.unpack_first();
3678            if datum.is_null() {
3679                None
3680            } else {
3681                Some(datum.unwrap_int64())
3682            }
3683        })
3684    }
3685
3686    /// Attempts to simplify this expression to a literal string.
3687    ///
3688    /// Returns `None` if this expression cannot be simplified, e.g. because it
3689    /// contains non-literal values.
3690    ///
3691    /// # Panics
3692    ///
3693    /// Panics if this expression does not have type [`SqlScalarType::String`].
3694    pub fn into_literal_string(self) -> Option<String> {
3695        self.simplify_to_literal().and_then(|row| {
3696            let datum = row.unpack_first();
3697            if datum.is_null() {
3698                None
3699            } else {
3700                Some(datum.unwrap_str().to_owned())
3701            }
3702        })
3703    }
3704
3705    /// Attempts to simplify this expression to a literal MzTimestamp.
3706    ///
3707    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3708    /// simplified, e.g. because it contains non-literal values or a cast fails.
3709    ///
3710    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3711    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3712    /// See `try_into_literal_int64` as an example.
3713    ///
3714    /// # Panics
3715    ///
3716    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3717    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3718        self.simplify_to_literal().and_then(|row| {
3719            let datum = row.unpack_first();
3720            if datum.is_null() {
3721                None
3722            } else {
3723                Some(datum.unwrap_mz_timestamp())
3724            }
3725        })
3726    }
3727
3728    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3729    /// returns it as an i64.
3730    ///
3731    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3732    /// - it's not a constant expression (as determined by `is_constant`)
3733    /// - evaluates to null
3734    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3735    ///
3736    /// # Panics
3737    ///
3738    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3739    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3740        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3741        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3742        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3743        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3744        // preconditions also of all the other into_literal_... functions.)
3745        if !self.is_constant() {
3746            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3747                "Expected a constant expression, got {}",
3748                self
3749            )));
3750        }
3751        self.clone()
3752            .simplify_to_literal_with_result()
3753            .and_then(|row| {
3754                let datum = row.unpack_first();
3755                if datum.is_null() {
3756                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3757                        "Expected an expression that evaluates to a non-null value, got {}",
3758                        self
3759                    )))
3760                } else {
3761                    Ok(datum.unwrap_int64())
3762                }
3763            })
3764    }
3765
3766    pub fn contains_parameters(&self) -> bool {
3767        let mut contains_parameters = false;
3768        #[allow(deprecated)]
3769        let _ = self.visit_recursively(0, &mut |_depth: usize,
3770                                                expr: &HirScalarExpr|
3771         -> Result<(), ()> {
3772            if let HirScalarExpr::Parameter(..) = expr {
3773                contains_parameters = true;
3774            }
3775            Ok(())
3776        });
3777        contains_parameters
3778    }
3779}
3780
3781impl VisitChildren<Self> for HirScalarExpr {
3782    fn visit_children<F>(&self, mut f: F)
3783    where
3784        F: FnMut(&Self),
3785    {
3786        use HirScalarExpr::*;
3787        match self {
3788            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3789            CallUnary { expr, .. } => f(expr),
3790            CallBinary { expr1, expr2, .. } => {
3791                f(expr1);
3792                f(expr2);
3793            }
3794            CallVariadic { exprs, .. } => {
3795                for expr in exprs {
3796                    f(expr);
3797                }
3798            }
3799            If {
3800                cond,
3801                then,
3802                els,
3803                name: _,
3804            } => {
3805                f(cond);
3806                f(then);
3807                f(els);
3808            }
3809            Exists(..) | Select(..) => (),
3810            Windowing(expr, _name) => expr.visit_children(f),
3811        }
3812    }
3813
3814    fn visit_mut_children<F>(&mut self, mut f: F)
3815    where
3816        F: FnMut(&mut Self),
3817    {
3818        use HirScalarExpr::*;
3819        match self {
3820            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3821            CallUnary { expr, .. } => f(expr),
3822            CallBinary { expr1, expr2, .. } => {
3823                f(expr1);
3824                f(expr2);
3825            }
3826            CallVariadic { exprs, .. } => {
3827                for expr in exprs {
3828                    f(expr);
3829                }
3830            }
3831            If {
3832                cond,
3833                then,
3834                els,
3835                name: _,
3836            } => {
3837                f(cond);
3838                f(then);
3839                f(els);
3840            }
3841            Exists(..) | Select(..) => (),
3842            Windowing(expr, _name) => expr.visit_mut_children(f),
3843        }
3844    }
3845
3846    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3847    where
3848        F: FnMut(&Self) -> Result<(), E>,
3849        E: From<RecursionLimitError>,
3850    {
3851        use HirScalarExpr::*;
3852        match self {
3853            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3854            CallUnary { expr, .. } => f(expr)?,
3855            CallBinary { expr1, expr2, .. } => {
3856                f(expr1)?;
3857                f(expr2)?;
3858            }
3859            CallVariadic { exprs, .. } => {
3860                for expr in exprs {
3861                    f(expr)?;
3862                }
3863            }
3864            If {
3865                cond,
3866                then,
3867                els,
3868                name: _,
3869            } => {
3870                f(cond)?;
3871                f(then)?;
3872                f(els)?;
3873            }
3874            Exists(..) | Select(..) => (),
3875            Windowing(expr, _name) => expr.try_visit_children(f)?,
3876        }
3877        Ok(())
3878    }
3879
3880    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3881    where
3882        F: FnMut(&mut Self) -> Result<(), E>,
3883        E: From<RecursionLimitError>,
3884    {
3885        use HirScalarExpr::*;
3886        match self {
3887            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3888            CallUnary { expr, .. } => f(expr)?,
3889            CallBinary { expr1, expr2, .. } => {
3890                f(expr1)?;
3891                f(expr2)?;
3892            }
3893            CallVariadic { exprs, .. } => {
3894                for expr in exprs {
3895                    f(expr)?;
3896                }
3897            }
3898            If {
3899                cond,
3900                then,
3901                els,
3902                name: _,
3903            } => {
3904                f(cond)?;
3905                f(then)?;
3906                f(els)?;
3907            }
3908            Exists(..) | Select(..) => (),
3909            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3910        }
3911        Ok(())
3912    }
3913}
3914
3915impl AbstractExpr for HirScalarExpr {
3916    type Type = SqlColumnType;
3917
3918    fn typ(
3919        &self,
3920        outers: &[SqlRelationType],
3921        inner: &SqlRelationType,
3922        params: &BTreeMap<usize, SqlScalarType>,
3923    ) -> Self::Type {
3924        stack::maybe_grow(|| match self {
3925            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3926                if *level == 0 {
3927                    inner.column_types[*column].clone()
3928                } else {
3929                    outers[*level - 1].column_types[*column].clone()
3930                }
3931            }
3932            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3933            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3934            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3935            HirScalarExpr::CallUnary {
3936                expr,
3937                func,
3938                name: _,
3939            } => func.output_type(expr.typ(outers, inner, params)),
3940            HirScalarExpr::CallBinary {
3941                expr1,
3942                expr2,
3943                func,
3944                name: _,
3945            } => func.output_type(&[
3946                expr1.typ(outers, inner, params),
3947                expr2.typ(outers, inner, params),
3948            ]),
3949            HirScalarExpr::CallVariadic {
3950                exprs,
3951                func,
3952                name: _,
3953            } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3954            HirScalarExpr::If {
3955                cond: _,
3956                then,
3957                els,
3958                name: _,
3959            } => {
3960                let then_type = then.typ(outers, inner, params);
3961                let else_type = els.typ(outers, inner, params);
3962                then_type.sql_union(&else_type).unwrap() // HIR deliberately not using `union`
3963            }
3964            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3965            HirScalarExpr::Select(expr, _name) => {
3966                let mut outers = outers.to_vec();
3967                outers.insert(0, inner.clone());
3968                expr.typ(&outers, params)
3969                    .column_types
3970                    .into_element()
3971                    .nullable(true)
3972            }
3973            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3974        })
3975    }
3976}
3977
3978impl AggregateExpr {
3979    pub fn typ(
3980        &self,
3981        outers: &[SqlRelationType],
3982        inner: &SqlRelationType,
3983        params: &BTreeMap<usize, SqlScalarType>,
3984    ) -> SqlColumnType {
3985        self.func.output_type(self.expr.typ(outers, inner, params))
3986    }
3987
3988    /// Returns whether the expression is COUNT(*) or not.  Note that
3989    /// when we define the count builtin in sql::func, we convert
3990    /// COUNT(*) to COUNT(true), making it indistinguishable from
3991    /// literal COUNT(true), but we prefer to consider this as the
3992    /// former.
3993    ///
3994    /// (MIR has the same `is_count_asterisk`.)
3995    pub fn is_count_asterisk(&self) -> bool {
3996        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3997    }
3998}