Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25use mz_expr::func::variadic::{And, Or};
26pub use mz_expr::{
27    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
28};
29use mz_ore::collections::CollectionExt;
30use mz_ore::error::ErrorExt;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{
41    EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context, offset_into_value,
42};
43use crate::plan::typeconv::{self, CastContext, plan_cast};
44use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
45
46use super::plan_utils::GroupSizeHints;
47
/// Marker type that plugs the HIR expression types ([`HirRelationExpr`] and
/// [`HirScalarExpr`]) into the generic `mz_expr::virtual_syntax` machinery via
/// its [`IR`] implementation.
// `Hir` carries no data and derives no `Debug`, so the crate-level
// `missing_debug_implementations` lint is silenced here.
#[allow(missing_debug_implementations)]
pub struct Hir;

/// Associates the HIR relation and scalar expression types with the [`Hir`] marker.
impl IR for Hir {
    type Relation = HirRelationExpr;
    type Scalar = HirScalarExpr;
}
55
56impl AlgExcept for Hir {
57    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
58        if *all {
59            let rhs = rhs.negate();
60            HirRelationExpr::union(lhs, rhs).threshold()
61        } else {
62            let lhs = lhs.distinct();
63            let rhs = rhs.distinct().negate();
64            HirRelationExpr::union(lhs, rhs).threshold()
65        }
66    }
67
68    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
69        let mut result = None;
70
71        use HirRelationExpr::*;
72        if let Threshold { input } = expr {
73            if let Union { base: lhs, inputs } = input.as_ref() {
74                if let [rhs] = &inputs[..] {
75                    if let Negate { input: rhs } = rhs {
76                        match (lhs.as_ref(), rhs.as_ref()) {
77                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
78                                let all = false;
79                                let lhs = lhs.as_ref();
80                                let rhs = rhs.as_ref();
81                                result = Some(Except { all, lhs, rhs })
82                            }
83                            (lhs, rhs) => {
84                                let all = true;
85                                result = Some(Except { all, lhs, rhs })
86                            }
87                        }
88                    }
89                }
90            }
91        }
92
93        result
94    }
95}
96
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
///
/// Note: `PartialOrd`, `Ord` and `Hash` are derived, so the order in which
/// variants and fields are declared is significant.
pub enum HirRelationExpr {
    /// A constant collection consisting of `rows`, described by `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// Reads the collection identified by `id` (for example one bound by an
    /// enclosing `Let` or `LetRec`); `typ` is its type.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        /// Each entry is presumably `(name, id, value, type)`, mirroring the
        /// fields of `Let` — TODO confirm against the planner.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// Keeps only the `outputs` columns of `input`, in the listed order.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Appends one column per expression in `scalars` to each row of `input`.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Invokes the table function `func` on the argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps the rows of `input` that satisfy all of `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    ///
    /// `on` is the join condition and `kind` the kind of join.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// User-supplied hint, like `TopK::expected_group_size`: how many rows
        /// will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicity of every row of `input` (used, together with
    /// `Union` and `Threshold`, to encode `EXCEPT`).
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every collection in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
210
/// Stored column metadata: the optional column name attached to each
/// [`HirScalarExpr`] variant.
///
/// NOTE(review): it is wrapped in `TreatAsEqual`, presumably so that the name
/// does not influence the derived equality/ordering/hashing of the expressions
/// that carry it — confirm against `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
213
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant additionally carries a [`NameMetadata`] (an optional name).
/// `PartialOrd`, `Ord` and `Hash` are derived, so declaration order matters.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A query parameter (`$n`), identified by its number.
    Parameter(usize, NameMetadata),
    /// A constant value together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an [`UnmaterializableFunc`], which takes no arguments.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a one-argument function.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a two-argument function.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function taking any number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Evaluates to `then` when `cond` is true and to `els` otherwise.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
265
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, including its own arguments.
    pub func: WindowExprType,
    /// The `PARTITION BY` expressions; empty when there is no partitioning.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
294
295impl WindowExpr {
296    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
297    where
298        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
299    {
300        #[allow(deprecated)]
301        self.func.visit_expressions(f)?;
302        for expr in self.partition_by.iter() {
303            f(expr)?;
304        }
305        for expr in self.order_by.iter() {
306            f(expr)?;
307        }
308        Ok(())
309    }
310
311    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
312    where
313        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
314    {
315        #[allow(deprecated)]
316        self.func.visit_expressions_mut(f)?;
317        for expr in self.partition_by.iter_mut() {
318            f(expr)?;
319        }
320        for expr in self.order_by.iter_mut() {
321            f(expr)?;
322        }
323        Ok(())
324    }
325}
326
327/// Yields the scalars in `func`'s arguments, plus those in `partition_by`
328/// and `order_by`; does not descend into them.
329impl VisitChildren<HirScalarExpr> for WindowExpr {
330    fn visit_children<F>(&self, mut f: F)
331    where
332        F: FnMut(&HirScalarExpr),
333    {
334        self.func.visit_children(&mut f);
335        for expr in self.partition_by.iter() {
336            f(expr);
337        }
338        for expr in self.order_by.iter() {
339            f(expr);
340        }
341    }
342
343    fn visit_mut_children<F>(&mut self, mut f: F)
344    where
345        F: FnMut(&mut HirScalarExpr),
346    {
347        self.func.visit_mut_children(&mut f);
348        for expr in self.partition_by.iter_mut() {
349            f(expr);
350        }
351        for expr in self.order_by.iter_mut() {
352            f(expr);
353        }
354    }
355
356    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
357    where
358        F: FnMut(&HirScalarExpr) -> Result<(), E>,
359    {
360        self.func.try_visit_children(&mut f)?;
361        for expr in self.partition_by.iter() {
362            f(expr)?;
363        }
364        for expr in self.order_by.iter() {
365            f(expr)?;
366        }
367        Ok(())
368    }
369
370    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
371    where
372        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
373    {
374        self.func.try_visit_mut_children(&mut f)?;
375        for expr in self.partition_by.iter_mut() {
376            f(expr)?;
377        }
378        for expr in self.order_by.iter_mut() {
379            f(expr)?;
380        }
381        Ok(())
382    }
383
384    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
385    where
386        HirScalarExpr: 'a,
387    {
388        self.func
389            .children()
390            .chain(self.partition_by.iter())
391            .chain(self.order_by.iter())
392    }
393
394    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
395    where
396        HirScalarExpr: 'a,
397    {
398        self.func
399            .children_mut()
400            .chain(self.partition_by.iter_mut())
401            .chain(self.order_by.iter_mut())
402    }
403}
404
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function such as `row_number`, `rank` or `dense_rank`.
    Scalar(ScalarWindowExpr),
    /// A value window function such as `lag`, `lead`, `first_value` or `last_value`.
    Value(ValueWindowExpr),
    /// An aggregate used as a window function, e.g. `sum(x) OVER (...)`.
    Aggregate(AggregateWindowExpr),
}
437
438impl WindowExprType {
439    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
440    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
441    where
442        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
443    {
444        #[allow(deprecated)]
445        match self {
446            Self::Scalar(expr) => expr.visit_expressions(f),
447            Self::Value(expr) => expr.visit_expressions(f),
448            Self::Aggregate(expr) => expr.visit_expressions(f),
449        }
450    }
451
452    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
453    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
454    where
455        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
456    {
457        #[allow(deprecated)]
458        match self {
459            Self::Scalar(expr) => expr.visit_expressions_mut(f),
460            Self::Value(expr) => expr.visit_expressions_mut(f),
461            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
462        }
463    }
464
465    fn typ(
466        &self,
467        outers: &[SqlRelationType],
468        inner: &SqlRelationType,
469        params: &BTreeMap<usize, SqlScalarType>,
470    ) -> SqlColumnType {
471        match self {
472            Self::Scalar(expr) => expr.typ(outers, inner, params),
473            Self::Value(expr) => expr.typ(outers, inner, params),
474            Self::Aggregate(expr) => expr.typ(outers, inner, params),
475        }
476    }
477}
478
479/// Dispatches to the inner `Value` / `Aggregate` variant's scalar children;
480/// `Scalar` window functions have no scalar children.
481impl VisitChildren<HirScalarExpr> for WindowExprType {
482    fn visit_children<F>(&self, f: F)
483    where
484        F: FnMut(&HirScalarExpr),
485    {
486        match self {
487            Self::Scalar(_) => (),
488            Self::Value(expr) => expr.visit_children(f),
489            Self::Aggregate(expr) => expr.visit_children(f),
490        }
491    }
492
493    fn visit_mut_children<F>(&mut self, f: F)
494    where
495        F: FnMut(&mut HirScalarExpr),
496    {
497        match self {
498            Self::Scalar(_) => (),
499            Self::Value(expr) => expr.visit_mut_children(f),
500            Self::Aggregate(expr) => expr.visit_mut_children(f),
501        }
502    }
503
504    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
505    where
506        F: FnMut(&HirScalarExpr) -> Result<(), E>,
507    {
508        match self {
509            Self::Scalar(_) => Ok(()),
510            Self::Value(expr) => expr.try_visit_children(f),
511            Self::Aggregate(expr) => expr.try_visit_children(f),
512        }
513    }
514
515    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
516    where
517        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
518    {
519        match self {
520            Self::Scalar(_) => Ok(()),
521            Self::Value(expr) => expr.try_visit_mut_children(f),
522            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
523        }
524    }
525
526    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
527    where
528        HirScalarExpr: 'a,
529    {
530        match self {
531            Self::Scalar(_) => vec![],
532            Self::Value(expr) => expr.children().collect(),
533            Self::Aggregate(expr) => expr.children().collect(),
534        }
535        .into_iter()
536    }
537
538    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
539    where
540        HirScalarExpr: 'a,
541    {
542        match self {
543            Self::Scalar(_) => vec![],
544            Self::Value(expr) => expr.children_mut().collect(),
545            Self::Aggregate(expr) => expr.children_mut().collect(),
546        }
547        .into_iter()
548    }
549}
550
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A call to a scalar window function (e.g. `row_number()`), which takes no
/// argument expressions.
pub struct ScalarWindowExpr {
    /// Which scalar window function is called.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
566
567impl ScalarWindowExpr {
568    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
569    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
570    where
571        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
572    {
573        match self.func {
574            ScalarWindowFunc::RowNumber => {}
575            ScalarWindowFunc::Rank => {}
576            ScalarWindowFunc::DenseRank => {}
577        }
578        Ok(())
579    }
580
581    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
582    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
583    where
584        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
585    {
586        match self.func {
587            ScalarWindowFunc::RowNumber => {}
588            ScalarWindowFunc::Rank => {}
589            ScalarWindowFunc::DenseRank => {}
590        }
591        Ok(())
592    }
593
594    fn typ(
595        &self,
596        _outers: &[SqlRelationType],
597        _inner: &SqlRelationType,
598        _params: &BTreeMap<usize, SqlScalarType>,
599    ) -> SqlColumnType {
600        self.func.output_sql_type()
601    }
602
603    pub fn into_expr(self) -> mz_expr::AggregateFunc {
604        match self.func {
605            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
606                order_by: self.order_by,
607            },
608            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
609                order_by: self.order_by,
610            },
611            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
612                order_by: self.order_by,
613            },
614        }
615    }
616}
617
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Scalar window functions.
pub enum ScalarWindowFunc {
    /// `row_number()`
    RowNumber,
    /// `rank()`
    Rank,
    /// `dense_rank()`
    DenseRank,
}
635
636impl Display for ScalarWindowFunc {
637    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
638        match self {
639            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
640            ScalarWindowFunc::Rank => write!(f, "rank"),
641            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
642        }
643    }
644}
645
646impl ScalarWindowFunc {
647    pub fn output_sql_type(&self) -> SqlColumnType {
648        match self {
649            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
650            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
651            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
652        }
653    }
654}
655
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A call to a value window function (`lag`, `lead`, `first_value`,
/// `last_value`, or a fusion of several of them).
pub struct ValueWindowExpr {
    /// Which value window function is called.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. When lowered by `ValueWindowFunc::into_expr`, only
    /// `first_value`/`last_value` (including inside `Fused`) use it.
    pub window_frame: WindowFrame,
    /// Presumably set by `IGNORE NULLS` — TODO confirm. When lowered, only
    /// `lag`/`lead` (including inside `Fused`) use it.
    pub ignore_nulls: bool,
}
680
681impl Display for ValueWindowFunc {
682    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
683        match self {
684            ValueWindowFunc::Lag => write!(f, "lag"),
685            ValueWindowFunc::Lead => write!(f, "lead"),
686            ValueWindowFunc::FirstValue => write!(f, "first_value"),
687            ValueWindowFunc::LastValue => write!(f, "last_value"),
688            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
689        }
690    }
691}
692
693impl ValueWindowExpr {
694    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
695    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
696    where
697        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
698    {
699        f(&self.args)
700    }
701
702    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
703    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
704    where
705        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
706    {
707        f(&mut self.args)
708    }
709
710    fn typ(
711        &self,
712        outers: &[SqlRelationType],
713        inner: &SqlRelationType,
714        params: &BTreeMap<usize, SqlScalarType>,
715    ) -> SqlColumnType {
716        self.func
717            .output_sql_type(self.args.typ(outers, inner, params))
718    }
719
720    /// Converts into `mz_expr::AggregateFunc`.
721    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
722        (
723            self.args,
724            self.func
725                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
726        )
727    }
728}
729
730/// Yields `args` (the value-window function's argument expression);
731/// does not descend into it.
732impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
733    // `visit_children` and friends are not implemented explicitly: the trait
734    // defaults delegate to `children`/`children_mut` below, which yield the
735    // single argument expression.
736
737    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
738    where
739        HirScalarExpr: 'a,
740    {
741        // Yield `args` itself; we must not descend into it (see the impl-level
742        // doc comment). Descending would skip `args` when it is a leaf node
743        // (e.g. `first_value(mz_now())`), breaking `contains_temporal`.
744        std::iter::once(&*self.args)
745    }
746
747    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
748    where
749        HirScalarExpr: 'a,
750    {
751        std::iter::once(&mut *self.args)
752    }
753}
754
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Value window functions.
pub enum ValueWindowFunc {
    /// `lag(value, offset, default)`
    Lag,
    /// `lead(value, offset, default)`
    Lead,
    /// `first_value(value)`
    FirstValue,
    /// `last_value(value)`
    LastValue,
    /// Several value window functions computed together. The arguments of each
    /// constituent call are wrapped in an outer record (see `ValueWindowExpr::args`).
    Fused(Vec<ValueWindowFunc>),
}
774
775impl ValueWindowFunc {
776    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
777        match self {
778            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
779                // The input is a (value, offset, default) record, so extract the type of the first arg
780                input_type.scalar_type.unwrap_record_element_type()[0]
781                    .clone()
782                    .nullable(true)
783            }
784            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
785                input_type.scalar_type.nullable(true)
786            }
787            ValueWindowFunc::Fused(funcs) => {
788                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
789                SqlScalarType::Record {
790                    fields: funcs
791                        .iter()
792                        .zip_eq(input_types)
793                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
794                        .collect(),
795                    custom_id: None,
796                }
797                .nullable(false)
798            }
799        }
800    }
801
802    pub fn into_expr(
803        self,
804        order_by: Vec<ColumnOrder>,
805        window_frame: WindowFrame,
806        ignore_nulls: bool,
807    ) -> mz_expr::AggregateFunc {
808        match self {
809            // Lag and Lead are fundamentally the same function, just with opposite directions
810            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
811                order_by,
812                lag_lead: mz_expr::LagLeadType::Lag,
813                ignore_nulls,
814            },
815            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
816                order_by,
817                lag_lead: mz_expr::LagLeadType::Lead,
818                ignore_nulls,
819            },
820            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
821                order_by,
822                window_frame,
823            },
824            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
825                order_by,
826                window_frame,
827            },
828            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
829                funcs: funcs
830                    .into_iter()
831                    .map(|func| {
832                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
833                    })
834                    .collect(),
835                order_by,
836            },
837        }
838    }
839}
840
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// An aggregate used as a window function, e.g. `sum(x) OVER (...)`.
pub struct AggregateWindowExpr {
    /// The aggregate function together with its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame over which the aggregate is evaluated.
    pub window_frame: WindowFrame,
}
857
858impl AggregateWindowExpr {
859    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
860    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
861    where
862        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
863    {
864        f(&self.aggregate_expr.expr)
865    }
866
867    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
868    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
869    where
870        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
871    {
872        f(&mut self.aggregate_expr.expr)
873    }
874
875    fn typ(
876        &self,
877        outers: &[SqlRelationType],
878        inner: &SqlRelationType,
879        params: &BTreeMap<usize, SqlScalarType>,
880    ) -> SqlColumnType {
881        self.aggregate_expr
882            .func
883            .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params))
884    }
885
886    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
887        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
888            (
889                self.aggregate_expr.expr,
890                FusedWindowAggregate {
891                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
892                    order_by: self.order_by,
893                    window_frame: self.window_frame,
894                },
895            )
896        } else {
897            (
898                self.aggregate_expr.expr,
899                WindowAggregate {
900                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
901                    order_by: self.order_by,
902                    window_frame: self.window_frame,
903                },
904            )
905        }
906    }
907}
908
909/// Yields the aggregate's argument expression (`aggregate_expr.expr`);
910/// does not descend into it.
911impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
912    // `visit_children` and friends are not implemented explicitly: the trait
913    // defaults delegate to `children`/`children_mut` below, which yield the
914    // single aggregate argument expression.
915
916    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
917    where
918        HirScalarExpr: 'a,
919    {
920        // Yield the aggregate's argument expression itself; we must not descend
921        // into it. Descending would skip the argument when it is a leaf node
922        // (e.g. `max(mz_now())`), breaking `contains_temporal`.
923        std::iter::once(&*self.aggregate_expr.expr)
924    }
925
926    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
927    where
928        HirScalarExpr: 'a,
929    {
930        std::iter::once(&mut *self.aggregate_expr.expr)
931    }
932}
933
934/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
935/// determined. Several SQL expressions can be freely coerced based upon where
936/// in the expression tree they appear. For example, the string literal '42'
937/// will be automatically coerced to the integer 42 if used in a numeric
938/// context:
939///
940/// ```sql
941/// SELECT '42' + 42
942/// ```
943///
944/// This separate type gives the code that needs to interact with coercions very
945/// fine-grained control over what coercions happen and when.
946///
947/// The primary driver of coercion is function and operator selection, as
948/// choosing the correct function or operator implementation depends on the type
949/// of the provided arguments. Coercion also occurs at the very root of the
950/// scalar expression tree. For example in
951///
952/// ```sql
953/// SELECT ... WHERE $1
954/// ```
955///
956/// the `WHERE` clause will coerce the contained unconstrained type parameter
957/// `$1` to have type bool.
958#[derive(Clone, Debug)]
959pub enum CoercibleScalarExpr {
960    Coerced(HirScalarExpr),
961    Parameter(usize),
962    LiteralNull,
963    LiteralString(String),
964    LiteralRecord(Vec<CoercibleScalarExpr>),
965}
966
967impl CoercibleScalarExpr {
968    pub fn type_as(
969        self,
970        ecx: &ExprContext,
971        ty: &SqlScalarType,
972    ) -> Result<HirScalarExpr, PlanError> {
973        let expr = typeconv::plan_coerce(ecx, self, ty)?;
974        let expr_ty = ecx.scalar_type(&expr);
975        if ty != &expr_ty {
976            sql_bail!(
977                "{} must have type {}, not type {}",
978                ecx.name,
979                ecx.humanize_sql_scalar_type(ty, false),
980                ecx.humanize_sql_scalar_type(&expr_ty, false),
981            );
982        }
983        Ok(expr)
984    }
985
986    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
987        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
988    }
989
990    pub fn cast_to(
991        self,
992        ecx: &ExprContext,
993        ccx: CastContext,
994        ty: &SqlScalarType,
995    ) -> Result<HirScalarExpr, PlanError> {
996        let expr = typeconv::plan_coerce(ecx, self, ty)?;
997        typeconv::plan_cast(ecx, ccx, expr, ty)
998    }
999}
1000
1001/// The column type for a [`CoercibleScalarExpr`].
1002#[derive(Clone, Debug)]
1003pub enum CoercibleColumnType {
1004    Coerced(SqlColumnType),
1005    Record(Vec<CoercibleColumnType>),
1006    Uncoerced,
1007}
1008
1009impl CoercibleColumnType {
1010    /// Reports the nullability of the type.
1011    pub fn nullable(&self) -> bool {
1012        match self {
1013            // A coerced value's nullability is known.
1014            CoercibleColumnType::Coerced(ct) => ct.nullable,
1015
1016            // A literal record can never be null.
1017            CoercibleColumnType::Record(_) => false,
1018
1019            // An uncoerced literal may be the literal `NULL`, so we have
1020            // to conservatively assume it is nullable.
1021            CoercibleColumnType::Uncoerced => true,
1022        }
1023    }
1024}
1025
1026/// The scalar type for a [`CoercibleScalarExpr`].
1027#[derive(Clone, Debug)]
1028pub enum CoercibleScalarType {
1029    Coerced(SqlScalarType),
1030    Record(Vec<CoercibleColumnType>),
1031    Uncoerced,
1032}
1033
1034impl CoercibleScalarType {
1035    /// Reports whether the scalar type has been coerced.
1036    pub fn is_coerced(&self) -> bool {
1037        matches!(self, CoercibleScalarType::Coerced(_))
1038    }
1039
1040    /// Returns the coerced scalar type, if the type is coerced.
1041    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1042        match self {
1043            CoercibleScalarType::Coerced(t) => Some(t),
1044            _ => None,
1045        }
1046    }
1047
1048    /// If the type is coerced, apply the mapping function to the contained
1049    /// scalar type.
1050    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1051    where
1052        F: FnOnce(SqlScalarType) -> SqlScalarType,
1053    {
1054        match self {
1055            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1056            _ => self,
1057        }
1058    }
1059
1060    /// If the type is an coercible record, forcibly converts to a coerced
1061    /// record type. Any uncoerced field types are assumed to be of type text.
1062    ///
1063    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
1064    /// accepts a type hint that can indicate the types of uncoerced field
1065    /// types.
1066    pub fn force_coerced_if_record(&mut self) {
1067        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1068            let mut fields = vec![];
1069            for (i, uf) in uncoerced_fields.enumerate() {
1070                let name = ColumnName::from(format!("f{}", i + 1));
1071                let ty = match uf {
1072                    CoercibleColumnType::Coerced(ty) => ty,
1073                    CoercibleColumnType::Record(mut fields) => {
1074                        convert(fields.drain(..)).nullable(false)
1075                    }
1076                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1077                };
1078                fields.push((name, ty))
1079            }
1080            SqlScalarType::Record {
1081                fields: fields.into(),
1082                custom_id: None,
1083            }
1084        }
1085
1086        if let CoercibleScalarType::Record(fields) = self {
1087            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1088        }
1089    }
1090}
1091
1092/// An expression whose type can be ascertained.
1093///
1094/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
1095pub trait AbstractExpr {
1096    type Type: AbstractColumnType;
1097
1098    /// Computes the type of the expression.
1099    fn typ(
1100        &self,
1101        outers: &[SqlRelationType],
1102        inner: &SqlRelationType,
1103        params: &BTreeMap<usize, SqlScalarType>,
1104    ) -> Self::Type;
1105}
1106
1107impl AbstractExpr for CoercibleScalarExpr {
1108    type Type = CoercibleColumnType;
1109
1110    fn typ(
1111        &self,
1112        outers: &[SqlRelationType],
1113        inner: &SqlRelationType,
1114        params: &BTreeMap<usize, SqlScalarType>,
1115    ) -> Self::Type {
1116        match self {
1117            CoercibleScalarExpr::Coerced(expr) => {
1118                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1119            }
1120            CoercibleScalarExpr::LiteralRecord(scalars) => {
1121                let fields = scalars
1122                    .iter()
1123                    .map(|s| s.typ(outers, inner, params))
1124                    .collect();
1125                CoercibleColumnType::Record(fields)
1126            }
1127            _ => CoercibleColumnType::Uncoerced,
1128        }
1129    }
1130}
1131
1132/// A column type-like object whose underlying scalar type-like object can be
1133/// ascertained.
1134///
1135/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
1136pub trait AbstractColumnType {
1137    type AbstractScalarType;
1138
1139    /// Converts the column type-like object into its inner scalar type-like
1140    /// object.
1141    fn scalar_type(self) -> Self::AbstractScalarType;
1142}
1143
1144impl AbstractColumnType for SqlColumnType {
1145    type AbstractScalarType = SqlScalarType;
1146
1147    fn scalar_type(self) -> Self::AbstractScalarType {
1148        self.scalar_type
1149    }
1150}
1151
1152impl AbstractColumnType for CoercibleColumnType {
1153    type AbstractScalarType = CoercibleScalarType;
1154
1155    fn scalar_type(self) -> Self::AbstractScalarType {
1156        match self {
1157            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1158            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1159            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1160        }
1161    }
1162}
1163
1164impl From<HirScalarExpr> for CoercibleScalarExpr {
1165    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1166        CoercibleScalarExpr::Coerced(expr)
1167    }
1168}
1169
1170/// A leveled column reference.
1171///
1172/// In the course of decorrelation, multiple levels of nested subqueries are
1173/// traversed, and references to columns may correspond to different levels
1174/// of containing outer subqueries.
1175///
1176/// A `ColumnRef` allows expressions to refer to columns while being clear
1177/// about which level the column references without manually performing the
1178/// bookkeeping tracking their actual column locations.
1179///
1180/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
1181/// from the reference, using `column` as a unique identifier in that subquery level.
1182/// A `level` of zero corresponds to the current scope, and levels increase to
1183/// indicate subqueries further "outwards".
1184#[derive(
1185    Debug,
1186    Clone,
1187    Copy,
1188    PartialEq,
1189    Eq,
1190    Hash,
1191    Ord,
1192    PartialOrd,
1193    Serialize,
1194    Deserialize
1195)]
1196pub struct ColumnRef {
1197    // scope level, where 0 is the current scope and 1+ are outer scopes.
1198    pub level: usize,
1199    // level-local column identifier used.
1200    pub column: usize,
1201}
1202
1203#[derive(
1204    Debug,
1205    Clone,
1206    PartialEq,
1207    Eq,
1208    PartialOrd,
1209    Ord,
1210    Hash,
1211    Serialize,
1212    Deserialize
1213)]
1214pub enum JoinKind {
1215    Inner,
1216    LeftOuter,
1217    RightOuter,
1218    FullOuter,
1219}
1220
1221impl fmt::Display for JoinKind {
1222    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1223        write!(
1224            f,
1225            "{}",
1226            match self {
1227                JoinKind::Inner => "Inner",
1228                JoinKind::LeftOuter => "LeftOuter",
1229                JoinKind::RightOuter => "RightOuter",
1230                JoinKind::FullOuter => "FullOuter",
1231            }
1232        )
1233    }
1234}
1235
1236impl JoinKind {
1237    pub fn can_be_correlated(&self) -> bool {
1238        match self {
1239            JoinKind::Inner | JoinKind::LeftOuter => true,
1240            JoinKind::RightOuter | JoinKind::FullOuter => false,
1241        }
1242    }
1243
1244    pub fn can_elide_identity_left_join(&self) -> bool {
1245        match self {
1246            JoinKind::Inner | JoinKind::RightOuter => true,
1247            JoinKind::LeftOuter | JoinKind::FullOuter => false,
1248        }
1249    }
1250
1251    pub fn can_elide_identity_right_join(&self) -> bool {
1252        match self {
1253            JoinKind::Inner | JoinKind::LeftOuter => true,
1254            JoinKind::RightOuter | JoinKind::FullOuter => false,
1255        }
1256    }
1257}
1258
1259#[derive(
1260    Debug,
1261    Clone,
1262    PartialEq,
1263    Eq,
1264    PartialOrd,
1265    Ord,
1266    Hash,
1267    Serialize,
1268    Deserialize
1269)]
1270pub struct AggregateExpr {
1271    pub func: AggregateFunc,
1272    pub expr: Box<HirScalarExpr>,
1273    pub distinct: bool,
1274}
1275
1276/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
1277/// types may be different.
1278///
1279/// Specifically, the nullability of the aggregate columns is more common
1280/// here than in `expr`, as these aggregates may be applied over empty
1281/// result sets and should be null in those cases, whereas `expr` variants
1282/// only return null values when supplied nulls as input.
1283#[derive(
1284    Clone,
1285    Debug,
1286    Eq,
1287    PartialEq,
1288    PartialOrd,
1289    Ord,
1290    Hash,
1291    Serialize,
1292    Deserialize
1293)]
1294pub enum AggregateFunc {
1295    MaxNumeric,
1296    MaxInt16,
1297    MaxInt32,
1298    MaxInt64,
1299    MaxUInt16,
1300    MaxUInt32,
1301    MaxUInt64,
1302    MaxMzTimestamp,
1303    MaxFloat32,
1304    MaxFloat64,
1305    MaxBool,
1306    MaxString,
1307    MaxDate,
1308    MaxTimestamp,
1309    MaxTimestampTz,
1310    MaxInterval,
1311    MaxTime,
1312    MinNumeric,
1313    MinInt16,
1314    MinInt32,
1315    MinInt64,
1316    MinUInt16,
1317    MinUInt32,
1318    MinUInt64,
1319    MinMzTimestamp,
1320    MinFloat32,
1321    MinFloat64,
1322    MinBool,
1323    MinString,
1324    MinDate,
1325    MinTimestamp,
1326    MinTimestampTz,
1327    MinInterval,
1328    MinTime,
1329    SumInt16,
1330    SumInt32,
1331    SumInt64,
1332    SumUInt16,
1333    SumUInt32,
1334    SumUInt64,
1335    SumFloat32,
1336    SumFloat64,
1337    SumNumeric,
1338    Count,
1339    Any,
1340    All,
1341    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`s
1342    /// into a JSON list. The other elements are columns used by `order_by`.
1343    ///
1344    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
1345    /// layer, this function filters out `Datum::Null`, for consistency with
1346    /// the other aggregate functions.
1347    JsonbAgg {
1348        order_by: Vec<ColumnOrder>,
1349    },
1350    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum`s into a
1351    /// JSON map. The other elements are columns used by `order_by`.
1352    JsonbObjectAgg {
1353        order_by: Vec<ColumnOrder>,
1354    },
1355    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
1356    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
1357    /// elements are columns used by `order_by`.
1358    MapAgg {
1359        order_by: Vec<ColumnOrder>,
1360        value_type: SqlScalarType,
1361    },
1362    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
1363    /// single `Datum::Array`. The other elements are columns used by `order_by`.
1364    ArrayConcat {
1365        order_by: Vec<ColumnOrder>,
1366    },
1367    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
1368    /// single `Datum::List`. The other elements are columns used by `order_by`.
1369    ListConcat {
1370        order_by: Vec<ColumnOrder>,
1371    },
1372    StringAgg {
1373        order_by: Vec<ColumnOrder>,
1374    },
1375    /// A bundle of fused window aggregations: its input is a record, whose each
1376    /// component will be the input to one of the `AggregateFunc`s.
1377    ///
1378    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
1379    /// more specifically an `AggregateWindowExpr`.
1380    FusedWindowAgg {
1381        funcs: Vec<AggregateFunc>,
1382    },
1383    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
1384    ///
1385    /// Useful for removing an expensive aggregation while maintaining the shape
1386    /// of a reduce operator.
1387    Dummy,
1388}
1389
1390impl AggregateFunc {
1391    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1392    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1393        match self {
1394            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1395            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1396            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1397            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1398            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1399            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1400            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1401            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1402            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1403            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1404            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1405            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1406            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1407            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1408            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1409            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1410            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1411            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1412            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1413            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1414            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1415            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1416            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1417            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1418            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1419            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1420            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1421            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1422            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1423            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1424            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1425            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1426            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1427            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1428            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1429            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1430            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1431            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1432            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1433            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1434            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1435            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1436            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1437            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1438            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1439            AggregateFunc::All => mz_expr::AggregateFunc::All,
1440            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1441            AggregateFunc::JsonbObjectAgg { order_by } => {
1442                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1443            }
1444            AggregateFunc::MapAgg {
1445                order_by,
1446                value_type,
1447            } => mz_expr::AggregateFunc::MapAgg {
1448                order_by,
1449                value_type,
1450            },
1451            AggregateFunc::ArrayConcat { order_by } => {
1452                mz_expr::AggregateFunc::ArrayConcat { order_by }
1453            }
1454            AggregateFunc::ListConcat { order_by } => {
1455                mz_expr::AggregateFunc::ListConcat { order_by }
1456            }
1457            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1458            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1459            // `AggregateWindowExpr::into_expr`.
1460            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1461                panic!("into_expr called on FusedWindowAgg")
1462            }
1463            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1464        }
1465    }
1466
1467    /// Returns a datum whose inclusion in the aggregation will not change its
1468    /// result.
1469    ///
1470    /// # Panics
1471    ///
1472    /// Panics if called on a `FusedWindowAgg`.
1473    pub fn identity_datum(&self) -> Datum<'static> {
1474        match self {
1475            AggregateFunc::Any => Datum::False,
1476            AggregateFunc::All => Datum::True,
1477            AggregateFunc::Dummy => Datum::Dummy,
1478            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1479            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1480            AggregateFunc::MaxNumeric
1481            | AggregateFunc::MaxInt16
1482            | AggregateFunc::MaxInt32
1483            | AggregateFunc::MaxInt64
1484            | AggregateFunc::MaxUInt16
1485            | AggregateFunc::MaxUInt32
1486            | AggregateFunc::MaxUInt64
1487            | AggregateFunc::MaxMzTimestamp
1488            | AggregateFunc::MaxFloat32
1489            | AggregateFunc::MaxFloat64
1490            | AggregateFunc::MaxBool
1491            | AggregateFunc::MaxString
1492            | AggregateFunc::MaxDate
1493            | AggregateFunc::MaxTimestamp
1494            | AggregateFunc::MaxTimestampTz
1495            | AggregateFunc::MaxInterval
1496            | AggregateFunc::MaxTime
1497            | AggregateFunc::MinNumeric
1498            | AggregateFunc::MinInt16
1499            | AggregateFunc::MinInt32
1500            | AggregateFunc::MinInt64
1501            | AggregateFunc::MinUInt16
1502            | AggregateFunc::MinUInt32
1503            | AggregateFunc::MinUInt64
1504            | AggregateFunc::MinMzTimestamp
1505            | AggregateFunc::MinFloat32
1506            | AggregateFunc::MinFloat64
1507            | AggregateFunc::MinBool
1508            | AggregateFunc::MinString
1509            | AggregateFunc::MinDate
1510            | AggregateFunc::MinTimestamp
1511            | AggregateFunc::MinTimestampTz
1512            | AggregateFunc::MinInterval
1513            | AggregateFunc::MinTime
1514            | AggregateFunc::SumInt16
1515            | AggregateFunc::SumInt32
1516            | AggregateFunc::SumInt64
1517            | AggregateFunc::SumUInt16
1518            | AggregateFunc::SumUInt32
1519            | AggregateFunc::SumUInt64
1520            | AggregateFunc::SumFloat32
1521            | AggregateFunc::SumFloat64
1522            | AggregateFunc::SumNumeric
1523            | AggregateFunc::Count
1524            | AggregateFunc::JsonbAgg { .. }
1525            | AggregateFunc::JsonbObjectAgg { .. }
1526            | AggregateFunc::MapAgg { .. }
1527            | AggregateFunc::StringAgg { .. } => Datum::Null,
1528            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1529                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1530                // in HIR planning, because it is introduced only during HIR transformation.
1531                //
1532                // The implementation could be something like the following, except that we need to
1533                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1534                // ```
1535                // let temp_storage = RowArena::new();
1536                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1537                // ```
1538                panic!("FusedWindowAgg doesn't have an identity_datum")
1539            }
1540        }
1541    }
1542
1543    /// The output column type for the result of an aggregation.
1544    ///
1545    /// The output column type also contains nullability information, which
1546    /// is (without further information) true for aggregations that are not
1547    /// counts.
1548    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1549        let scalar_type = match self {
1550            AggregateFunc::Count => SqlScalarType::Int64,
1551            AggregateFunc::Any => SqlScalarType::Bool,
1552            AggregateFunc::All => SqlScalarType::Bool,
1553            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1554            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1555            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1556            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1557            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1558                max_scale: Some(NumericMaxScale::ZERO),
1559            },
1560            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1561            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1562                max_scale: Some(NumericMaxScale::ZERO),
1563            },
1564            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1565                value_type: Box::new(value_type.clone()),
1566                custom_id: None,
1567            },
1568            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1569                match input_type.scalar_type {
1570                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1571                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1572                    _ => unreachable!(),
1573                }
1574            }
1575            AggregateFunc::MaxNumeric
1576            | AggregateFunc::MaxInt16
1577            | AggregateFunc::MaxInt32
1578            | AggregateFunc::MaxInt64
1579            | AggregateFunc::MaxUInt16
1580            | AggregateFunc::MaxUInt32
1581            | AggregateFunc::MaxUInt64
1582            | AggregateFunc::MaxMzTimestamp
1583            | AggregateFunc::MaxFloat32
1584            | AggregateFunc::MaxFloat64
1585            | AggregateFunc::MaxBool
1586            | AggregateFunc::MaxString
1587            | AggregateFunc::MaxDate
1588            | AggregateFunc::MaxTimestamp
1589            | AggregateFunc::MaxTimestampTz
1590            | AggregateFunc::MaxInterval
1591            | AggregateFunc::MaxTime
1592            | AggregateFunc::MinNumeric
1593            | AggregateFunc::MinInt16
1594            | AggregateFunc::MinInt32
1595            | AggregateFunc::MinInt64
1596            | AggregateFunc::MinUInt16
1597            | AggregateFunc::MinUInt32
1598            | AggregateFunc::MinUInt64
1599            | AggregateFunc::MinMzTimestamp
1600            | AggregateFunc::MinFloat32
1601            | AggregateFunc::MinFloat64
1602            | AggregateFunc::MinBool
1603            | AggregateFunc::MinString
1604            | AggregateFunc::MinDate
1605            | AggregateFunc::MinTimestamp
1606            | AggregateFunc::MinTimestampTz
1607            | AggregateFunc::MinInterval
1608            | AggregateFunc::MinTime
1609            | AggregateFunc::SumFloat32
1610            | AggregateFunc::SumFloat64
1611            | AggregateFunc::SumNumeric
1612            | AggregateFunc::Dummy => input_type.scalar_type,
1613            AggregateFunc::FusedWindowAgg { funcs } => {
1614                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1615                SqlScalarType::Record {
1616                    fields: funcs
1617                        .iter()
1618                        .zip_eq(input_types)
1619                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
1620                        .collect(),
1621                    custom_id: None,
1622                }
1623            }
1624        };
1625        // max/min/sum return null on empty sets
1626        let nullable = !matches!(self, AggregateFunc::Count);
1627        scalar_type.nullable(nullable)
1628    }
1629
1630    pub fn is_order_sensitive(&self) -> bool {
1631        use AggregateFunc::*;
1632        matches!(
1633            self,
1634            JsonbAgg { .. }
1635                | JsonbObjectAgg { .. }
1636                | MapAgg { .. }
1637                | ArrayConcat { .. }
1638                | ListConcat { .. }
1639                | StringAgg { .. }
1640        )
1641    }
1642}
1643
1644impl HirRelationExpr {
1645    /// Gets the SQL type of a self-contained, top-level expression.
1646    pub fn top_level_typ(&self) -> SqlRelationType {
1647        self.typ(&[], &BTreeMap::new())
1648    }
1649
1650    /// Gets the SQL type of the expression.
1651    ///
1652    /// `outers` gives types for outer relations.
1653    /// `params` gives types for parameters.
1654    pub fn typ(
1655        &self,
1656        outers: &[SqlRelationType],
1657        params: &BTreeMap<usize, SqlScalarType>,
1658    ) -> SqlRelationType {
1659        stack::maybe_grow(|| match self {
1660            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1661            HirRelationExpr::Get { typ, .. } => typ.clone(),
1662            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1663            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1664            HirRelationExpr::Project { input, outputs } => {
1665                let input_typ = input.typ(outers, params);
1666                SqlRelationType::new(
1667                    outputs
1668                        .iter()
1669                        .map(|&i| input_typ.column_types[i].clone())
1670                        .collect(),
1671                )
1672            }
1673            HirRelationExpr::Map { input, scalars } => {
1674                let mut typ = input.typ(outers, params);
1675                for scalar in scalars {
1676                    typ.column_types.push(scalar.typ(outers, &typ, params));
1677                }
1678                typ
1679            }
1680            HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(),
1681            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1682                input.typ(outers, params)
1683            }
1684            HirRelationExpr::Join {
1685                left, right, kind, ..
1686            } => {
1687                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1688                let right_nullable =
1689                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1690                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1691                    let nullable = t.nullable || left_nullable;
1692                    t.nullable(nullable)
1693                });
1694                let mut outers = outers.to_vec();
1695                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1696                let rt = right
1697                    .typ(&outers, params)
1698                    .column_types
1699                    .into_iter()
1700                    .map(|t| {
1701                        let nullable = t.nullable || right_nullable;
1702                        t.nullable(nullable)
1703                    });
1704                SqlRelationType::new(lt.chain(rt).collect())
1705            }
1706            HirRelationExpr::Reduce {
1707                input,
1708                group_key,
1709                aggregates,
1710                expected_group_size: _,
1711            } => {
1712                let input_typ = input.typ(outers, params);
1713                let mut column_types = group_key
1714                    .iter()
1715                    .map(|&i| input_typ.column_types[i].clone())
1716                    .collect::<Vec<_>>();
1717                for agg in aggregates {
1718                    column_types.push(agg.typ(outers, &input_typ, params));
1719                }
1720                // TODO(frank): add primary key information.
1721                SqlRelationType::new(column_types)
1722            }
1723            // TODO(frank): check for removal; add primary key information.
1724            HirRelationExpr::Distinct { input }
1725            | HirRelationExpr::Negate { input }
1726            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1727            HirRelationExpr::Union { base, inputs } => {
1728                let mut base_cols = base.typ(outers, params).column_types;
1729                for input in inputs {
1730                    for (base_col, col) in base_cols
1731                        .iter_mut()
1732                        .zip_eq(input.typ(outers, params).column_types)
1733                    {
1734                        *base_col = base_col.sql_union(&col).unwrap(); // HIR deliberately not using `union`
1735                    }
1736                }
1737                SqlRelationType::new(base_cols)
1738            }
1739        })
1740    }
1741
1742    pub fn arity(&self) -> usize {
1743        match self {
1744            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1745            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1746            HirRelationExpr::Let { body, .. } => body.arity(),
1747            HirRelationExpr::LetRec { body, .. } => body.arity(),
1748            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1749            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1750            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1751            HirRelationExpr::Filter { input, .. }
1752            | HirRelationExpr::TopK { input, .. }
1753            | HirRelationExpr::Distinct { input }
1754            | HirRelationExpr::Negate { input }
1755            | HirRelationExpr::Threshold { input } => input.arity(),
1756            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1757            HirRelationExpr::Union { base, .. } => base.arity(),
1758            HirRelationExpr::Reduce {
1759                group_key,
1760                aggregates,
1761                ..
1762            } => group_key.len() + aggregates.len(),
1763        }
1764    }
1765
1766    /// If self is a constant, return the value and the type, otherwise `None`.
1767    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1768        match self {
1769            Self::Constant { rows, typ } => Some((rows, typ)),
1770            _ => None,
1771        }
1772    }
1773
1774    /// Reports whether this expression contains a column reference to its
1775    /// direct parent scope.
1776    pub fn is_correlated(&self) -> bool {
1777        let mut correlated = false;
1778        #[allow(deprecated)]
1779        self.visit_columns(0, &mut |depth, col| {
1780            if col.level > depth && col.level - depth == 1 {
1781                correlated = true;
1782            }
1783        });
1784        correlated
1785    }
1786
1787    pub fn is_join_identity(&self) -> bool {
1788        match self {
1789            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1790            _ => false,
1791        }
1792    }
1793
1794    pub fn project(self, outputs: Vec<usize>) -> Self {
1795        if outputs.iter().copied().eq(0..self.arity()) {
1796            // The projection is trivial. Suppress it.
1797            self
1798        } else {
1799            HirRelationExpr::Project {
1800                input: Box::new(self),
1801                outputs,
1802            }
1803        }
1804    }
1805
1806    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1807        if scalars.is_empty() {
1808            // The map is trivial. Suppress it.
1809            self
1810        } else if let HirRelationExpr::Map {
1811            scalars: old_scalars,
1812            input: _,
1813        } = &mut self
1814        {
1815            // Map applied to a map. Fuse the maps.
1816            old_scalars.extend(scalars);
1817            self
1818        } else {
1819            HirRelationExpr::Map {
1820                input: Box::new(self),
1821                scalars,
1822            }
1823        }
1824    }
1825
1826    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1827        if let HirRelationExpr::Filter {
1828            input: _,
1829            predicates,
1830        } = &mut self
1831        {
1832            predicates.extend(preds);
1833            predicates.sort();
1834            predicates.dedup();
1835            self
1836        } else {
1837            preds.sort();
1838            preds.dedup();
1839            HirRelationExpr::Filter {
1840                input: Box::new(self),
1841                predicates: preds,
1842            }
1843        }
1844    }
1845
1846    pub fn reduce(
1847        self,
1848        group_key: Vec<usize>,
1849        aggregates: Vec<AggregateExpr>,
1850        expected_group_size: Option<u64>,
1851    ) -> Self {
1852        HirRelationExpr::Reduce {
1853            input: Box::new(self),
1854            group_key,
1855            aggregates,
1856            expected_group_size,
1857        }
1858    }
1859
1860    pub fn top_k(
1861        self,
1862        group_key: Vec<usize>,
1863        order_key: Vec<ColumnOrder>,
1864        limit: Option<HirScalarExpr>,
1865        offset: HirScalarExpr,
1866        expected_group_size: Option<u64>,
1867    ) -> Self {
1868        HirRelationExpr::TopK {
1869            input: Box::new(self),
1870            group_key,
1871            order_key,
1872            limit,
1873            offset,
1874            expected_group_size,
1875        }
1876    }
1877
1878    pub fn negate(self) -> Self {
1879        if let HirRelationExpr::Negate { input } = self {
1880            *input
1881        } else {
1882            HirRelationExpr::Negate {
1883                input: Box::new(self),
1884            }
1885        }
1886    }
1887
1888    pub fn distinct(self) -> Self {
1889        if let HirRelationExpr::Distinct { .. } = self {
1890            self
1891        } else {
1892            HirRelationExpr::Distinct {
1893                input: Box::new(self),
1894            }
1895        }
1896    }
1897
1898    pub fn threshold(self) -> Self {
1899        if let HirRelationExpr::Threshold { .. } = self {
1900            self
1901        } else {
1902            HirRelationExpr::Threshold {
1903                input: Box::new(self),
1904            }
1905        }
1906    }
1907
1908    pub fn union(self, other: Self) -> Self {
1909        let mut terms = Vec::new();
1910        if let HirRelationExpr::Union { base, inputs } = self {
1911            terms.push(*base);
1912            terms.extend(inputs);
1913        } else {
1914            terms.push(self);
1915        }
1916        if let HirRelationExpr::Union { base, inputs } = other {
1917            terms.push(*base);
1918            terms.extend(inputs);
1919        } else {
1920            terms.push(other);
1921        }
1922        HirRelationExpr::Union {
1923            base: Box::new(terms.remove(0)),
1924            inputs: terms,
1925        }
1926    }
1927
1928    pub fn exists(self) -> HirScalarExpr {
1929        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1930    }
1931
1932    pub fn select(self) -> HirScalarExpr {
1933        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1934    }
1935
1936    pub fn join(
1937        self,
1938        mut right: HirRelationExpr,
1939        on: HirScalarExpr,
1940        kind: JoinKind,
1941    ) -> HirRelationExpr {
1942        if self.is_join_identity()
1943            && !right.is_correlated()
1944            && on == HirScalarExpr::literal_true()
1945            && kind.can_elide_identity_left_join()
1946        {
1947            // The join can be elided, but we need to adjust column references
1948            // on the right-hand side to account for the removal of the scope
1949            // introduced by the join.
1950            #[allow(deprecated)]
1951            right.visit_columns_mut(0, &mut |depth, col| {
1952                if col.level > depth {
1953                    col.level -= 1;
1954                }
1955            });
1956            right
1957        } else if right.is_join_identity()
1958            && on == HirScalarExpr::literal_true()
1959            && kind.can_elide_identity_right_join()
1960        {
1961            self
1962        } else {
1963            HirRelationExpr::Join {
1964                left: Box::new(self),
1965                right: Box::new(right),
1966                on,
1967                kind,
1968            }
1969        }
1970    }
1971
1972    pub fn take(&mut self) -> HirRelationExpr {
1973        mem::replace(
1974            self,
1975            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1976        )
1977    }
1978
1979    #[deprecated = "Use `Visit::visit_post`."]
1980    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1981    where
1982        F: FnMut(&'a Self, usize),
1983    {
1984        #[allow(deprecated)]
1985        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1986                                                 depth: usize|
1987         -> Result<(), ()> {
1988            f(e, depth);
1989            Ok(())
1990        });
1991    }
1992
1993    #[deprecated = "Use `Visit::try_visit_post`."]
1994    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1995    where
1996        F: FnMut(&'a Self, usize) -> Result<(), E>,
1997    {
1998        #[allow(deprecated)]
1999        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
2000            e.visit_fallible(depth, f)
2001        })?;
2002        f(self, depth)
2003    }
2004
2005    /// WARNING: `VisitChildren<HirRelationExpr>::try_visit_children` is NOT a
2006    /// drop-in replacement: in addition to the relation children visited here,
2007    /// it also descends into every subquery (`Exists`/`Select`) reachable
2008    /// through `self`'s scalar children — at any depth of scalar nesting.
2009    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
2010    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
2011    where
2012        F: FnMut(&'a Self, usize) -> Result<(), E>,
2013    {
2014        match self {
2015            HirRelationExpr::Constant { .. }
2016            | HirRelationExpr::Get { .. }
2017            | HirRelationExpr::CallTable { .. } => (),
2018            HirRelationExpr::Let { body, value, .. } => {
2019                f(value, depth)?;
2020                f(body, depth)?;
2021            }
2022            HirRelationExpr::LetRec {
2023                limit: _,
2024                bindings,
2025                body,
2026            } => {
2027                for (_, _, value, _) in bindings.iter() {
2028                    f(value, depth)?;
2029                }
2030                f(body, depth)?;
2031            }
2032            HirRelationExpr::Project { input, .. } => {
2033                f(input, depth)?;
2034            }
2035            HirRelationExpr::Map { input, .. } => {
2036                f(input, depth)?;
2037            }
2038            HirRelationExpr::Filter { input, .. } => {
2039                f(input, depth)?;
2040            }
2041            HirRelationExpr::Join { left, right, .. } => {
2042                f(left, depth)?;
2043                f(right, depth + 1)?;
2044            }
2045            HirRelationExpr::Reduce { input, .. } => {
2046                f(input, depth)?;
2047            }
2048            HirRelationExpr::Distinct { input } => {
2049                f(input, depth)?;
2050            }
2051            HirRelationExpr::TopK { input, .. } => {
2052                f(input, depth)?;
2053            }
2054            HirRelationExpr::Negate { input } => {
2055                f(input, depth)?;
2056            }
2057            HirRelationExpr::Threshold { input } => {
2058                f(input, depth)?;
2059            }
2060            HirRelationExpr::Union { base, inputs } => {
2061                f(base, depth)?;
2062                for input in inputs {
2063                    f(input, depth)?;
2064                }
2065            }
2066        }
2067        Ok(())
2068    }
2069
2070    #[deprecated = "Use `Visit::visit_mut_post` instead."]
2071    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2072    where
2073        F: FnMut(&mut Self, usize),
2074    {
2075        #[allow(deprecated)]
2076        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2077                                                     depth: usize|
2078         -> Result<(), ()> {
2079            f(e, depth);
2080            Ok(())
2081        });
2082    }
2083
2084    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2085    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2086    where
2087        F: FnMut(&mut Self, usize) -> Result<(), E>,
2088    {
2089        #[allow(deprecated)]
2090        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2091            e.visit_mut_fallible(depth, f)
2092        })?;
2093        f(self, depth)
2094    }
2095
2096    /// WARNING: `VisitChildren<HirRelationExpr>::try_visit_mut_children` is NOT a
2097    /// drop-in replacement: in addition to the relation children visited here,
2098    /// it also descends into every subquery (`Exists`/`Select`) reachable
2099    /// through `self`'s scalar children — at any depth of scalar nesting.
2100    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2101    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2102    where
2103        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2104    {
2105        match self {
2106            HirRelationExpr::Constant { .. }
2107            | HirRelationExpr::Get { .. }
2108            | HirRelationExpr::CallTable { .. } => (),
2109            HirRelationExpr::Let { body, value, .. } => {
2110                f(value, depth)?;
2111                f(body, depth)?;
2112            }
2113            HirRelationExpr::LetRec {
2114                limit: _,
2115                bindings,
2116                body,
2117            } => {
2118                for (_, _, value, _) in bindings.iter_mut() {
2119                    f(value, depth)?;
2120                }
2121                f(body, depth)?;
2122            }
2123            HirRelationExpr::Project { input, .. } => {
2124                f(input, depth)?;
2125            }
2126            HirRelationExpr::Map { input, .. } => {
2127                f(input, depth)?;
2128            }
2129            HirRelationExpr::Filter { input, .. } => {
2130                f(input, depth)?;
2131            }
2132            HirRelationExpr::Join { left, right, .. } => {
2133                f(left, depth)?;
2134                f(right, depth + 1)?;
2135            }
2136            HirRelationExpr::Reduce { input, .. } => {
2137                f(input, depth)?;
2138            }
2139            HirRelationExpr::Distinct { input } => {
2140                f(input, depth)?;
2141            }
2142            HirRelationExpr::TopK { input, .. } => {
2143                f(input, depth)?;
2144            }
2145            HirRelationExpr::Negate { input } => {
2146                f(input, depth)?;
2147            }
2148            HirRelationExpr::Threshold { input } => {
2149                f(input, depth)?;
2150            }
2151            HirRelationExpr::Union { base, inputs } => {
2152                f(base, depth)?;
2153                for input in inputs {
2154                    f(input, depth)?;
2155                }
2156            }
2157        }
2158        Ok(())
2159    }
2160
2161    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2162    /// Visits all scalar expressions directly held by relation nodes within the sub-tree of `self`.
2163    ///
2164    /// Note: this does NOT descend into subqueries that may appear inside the visited
2165    /// `HirScalarExpr`s (i.e., `HirScalarExpr::Exists` / `HirScalarExpr::Select`). The closure
2166    /// `f` is invoked once per top-level scalar expression attached to a relation node, and it
2167    /// is the closure's responsibility to recurse into any subqueries if desired.
2168    ///
2169    /// The `depth` argument is just a seed: it is the value passed to `f` for scalar expressions
2170    /// at the root of `self`, and it is incremented by 1 when descending into the RHS of a
2171    /// `Join` node (the only place this function increments it). It does NOT control or limit
2172    /// recursion; passing `0` is the usual choice.
2173    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2174    where
2175        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2176    {
2177        #[allow(deprecated)]
2178        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2179                                         depth: usize|
2180         -> Result<(), E> {
2181            match e {
2182                HirRelationExpr::Join { on, .. } => {
2183                    f(on, depth)?;
2184                }
2185                HirRelationExpr::Map { scalars, .. } => {
2186                    for scalar in scalars {
2187                        f(scalar, depth)?;
2188                    }
2189                }
2190                HirRelationExpr::CallTable { exprs, .. } => {
2191                    for expr in exprs {
2192                        f(expr, depth)?;
2193                    }
2194                }
2195                HirRelationExpr::Filter { predicates, .. } => {
2196                    for predicate in predicates {
2197                        f(predicate, depth)?;
2198                    }
2199                }
2200                HirRelationExpr::Reduce { aggregates, .. } => {
2201                    for aggregate in aggregates {
2202                        f(&aggregate.expr, depth)?;
2203                    }
2204                }
2205                HirRelationExpr::TopK { limit, offset, .. } => {
2206                    if let Some(limit) = limit {
2207                        f(limit, depth)?;
2208                    }
2209                    f(offset, depth)?;
2210                }
2211                HirRelationExpr::Union { .. }
2212                | HirRelationExpr::Let { .. }
2213                | HirRelationExpr::LetRec { .. }
2214                | HirRelationExpr::Project { .. }
2215                | HirRelationExpr::Distinct { .. }
2216                | HirRelationExpr::Negate { .. }
2217                | HirRelationExpr::Threshold { .. }
2218                | HirRelationExpr::Constant { .. }
2219                | HirRelationExpr::Get { .. } => (),
2220            }
2221            Ok(())
2222        })
2223    }
2224
2225    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2226    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2227    ///
2228    /// In particular, this also does NOT descend into subqueries inside the visited scalar
2229    /// expressions, and `depth` is just a seed for the value passed to `f` (see
2230    /// [`HirRelationExpr::visit_scalar_expressions`]).
2231    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2232    where
2233        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2234    {
2235        #[allow(deprecated)]
2236        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2237                                             depth: usize|
2238         -> Result<(), E> {
2239            match e {
2240                HirRelationExpr::Join { on, .. } => {
2241                    f(on, depth)?;
2242                }
2243                HirRelationExpr::Map { scalars, .. } => {
2244                    for scalar in scalars.iter_mut() {
2245                        f(scalar, depth)?;
2246                    }
2247                }
2248                HirRelationExpr::CallTable { exprs, .. } => {
2249                    for expr in exprs.iter_mut() {
2250                        f(expr, depth)?;
2251                    }
2252                }
2253                HirRelationExpr::Filter { predicates, .. } => {
2254                    for predicate in predicates.iter_mut() {
2255                        f(predicate, depth)?;
2256                    }
2257                }
2258                HirRelationExpr::Reduce { aggregates, .. } => {
2259                    for aggregate in aggregates.iter_mut() {
2260                        f(&mut aggregate.expr, depth)?;
2261                    }
2262                }
2263                HirRelationExpr::TopK { limit, offset, .. } => {
2264                    if let Some(limit) = limit {
2265                        f(limit, depth)?;
2266                    }
2267                    f(offset, depth)?;
2268                }
2269                HirRelationExpr::Union { .. }
2270                | HirRelationExpr::Let { .. }
2271                | HirRelationExpr::LetRec { .. }
2272                | HirRelationExpr::Project { .. }
2273                | HirRelationExpr::Distinct { .. }
2274                | HirRelationExpr::Negate { .. }
2275                | HirRelationExpr::Threshold { .. }
2276                | HirRelationExpr::Constant { .. }
2277                | HirRelationExpr::Get { .. } => (),
2278            }
2279            Ok(())
2280        })
2281    }
2282
2283    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2284    /// Visits the column references in this relation expression.
2285    ///
2286    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2287    /// which will be incremented when entering the RHS of a join or a subquery and
2288    /// presented to the supplied function `f`.
2289    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2290    where
2291        F: FnMut(usize, &ColumnRef),
2292    {
2293        #[allow(deprecated)]
2294        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2295                                                           depth: usize|
2296         -> Result<(), ()> {
2297            e.visit_columns(depth, f);
2298            Ok(())
2299        });
2300    }
2301
2302    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2303    /// Like `visit_columns`, but permits mutating the column references.
2304    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2305    where
2306        F: FnMut(usize, &mut ColumnRef),
2307    {
2308        #[allow(deprecated)]
2309        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2310                                                               depth: usize|
2311         -> Result<(), ()> {
2312            e.visit_columns_mut(depth, f);
2313            Ok(())
2314        });
2315    }
2316
2317    /// Replaces parameter references in the expression with the corresponding datum from `params`.
2318    /// Additionally, it simplifies OFFSET clauses to constants after parameter binding, and checks
2319    /// them for non-negativity.
2320    ///
2321    /// This walks the entire `HirRelationExpr` tree, including subqueries nested inside scalar
2322    /// expressions: `visit_scalar_expressions_mut` covers the scalar expressions held directly
2323    /// by relation nodes, and the corecursive call into
2324    /// [`HirScalarExpr::bind_parameters_and_simplify_offset`] (made by the closure below) is
2325    /// what descends into subqueries occurring inside those scalar expressions.
2326    pub fn bind_parameters_and_simplify_offset(
2327        &mut self,
2328        scx: &StatementContext,
2329        lifetime: QueryLifetime,
2330        params: &Params,
2331    ) -> Result<(), PlanError> {
2332        #[allow(deprecated)]
2333        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2334            e.bind_parameters_and_simplify_offset(scx, lifetime, params)
2335        })?;
2336
2337        // OFFSET clauses in `expr` should become constants with the above binding of parameters.
2338        // Let's check this and simplify them to literals.
2339        self.try_visit_mut_pre(&mut |expr| {
2340            if let HirRelationExpr::TopK { offset, .. } = expr {
2341                let offset_value = offset_into_value(offset.take())?;
2342                *offset = HirScalarExpr::literal(Datum::Int64(offset_value), SqlScalarType::Int64);
2343            }
2344            Ok::<(), PlanError>(())
2345        })
2346        // (We don't need to simplify LIMIT clauses in `expr`, because we can handle non-constant
2347        // expressions there. If they happen to be simplifiable to literals, then the optimizer will do
2348        // so later.)
2349    }
2350
2351    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2352        let mut contains_parameters = false;
2353        #[allow(deprecated)]
2354        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2355            if e.contains_parameters() {
2356                contains_parameters = true;
2357            }
2358            Ok::<(), PlanError>(())
2359        })?;
2360        Ok(contains_parameters)
2361    }
2362
2363    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2364    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2365        #[allow(deprecated)]
2366        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2367                                                               depth: usize|
2368         -> Result<(), ()> {
2369            e.splice_parameters(params, depth);
2370            Ok(())
2371        });
2372    }
2373
2374    /// Constructs a constant collection from specific rows and schema.
2375    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2376        let rows = rows
2377            .into_iter()
2378            .map(move |datums| Row::pack_slice(&datums))
2379            .collect();
2380        HirRelationExpr::Constant { rows, typ }
2381    }
2382
2383    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2384    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2385    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2386    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2387    /// finishing into a trivial finishing.
2388    pub fn finish_maintained(
2389        &mut self,
2390        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2391        group_size_hints: GroupSizeHints,
2392    ) {
2393        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2394            let old_finishing = mem::replace(
2395                finishing,
2396                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2397            );
2398            *self = HirRelationExpr::top_k(
2399                std::mem::replace(
2400                    self,
2401                    HirRelationExpr::Constant {
2402                        rows: vec![],
2403                        typ: SqlRelationType::new(Vec::new()),
2404                    },
2405                ),
2406                vec![],
2407                old_finishing.order_by,
2408                old_finishing.limit,
2409                old_finishing.offset,
2410                group_size_hints.limit_input_group_size,
2411            )
2412            .project(old_finishing.project);
2413        }
2414    }
2415
2416    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2417    ///
2418    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2419    /// parameter is not an HirScalarExpr anymore.)
2420    pub fn trivial_row_set_finishing_hir(
2421        arity: usize,
2422    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2423        RowSetFinishing {
2424            order_by: Vec::new(),
2425            limit: None,
2426            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2427            project: (0..arity).collect(),
2428        }
2429    }
2430
2431    /// True if the finishing does nothing to any result set.
2432    ///
2433    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2434    /// parameter is not an HirScalarExpr anymore.)
2435    pub fn is_trivial_row_set_finishing_hir(
2436        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2437        arity: usize,
2438    ) -> bool {
2439        rsf.limit.is_none()
2440            && rsf.order_by.is_empty()
2441            && rsf
2442                .offset
2443                .clone()
2444                .try_into_literal_int64()
2445                .is_ok_and(|o| o == 0)
2446            && rsf.project.iter().copied().eq(0..arity)
2447    }
2448
2449    /// The HirRelationExpr is considered potentially expensive if and only if
2450    /// at least one of the following conditions is true:
2451    ///
2452    ///  - It contains at least one HirScalarExpr with a function call.
2453    ///  - It contains at least one CallTable or a Reduce operator.
2454    ///  - We run into a RecursionLimitError while analyzing the expression.
2455    ///
2456    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2457    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2458    pub fn could_run_expensive_function(&self) -> bool {
2459        let mut result = false;
2460        self.visit_pre(&mut |e: &HirRelationExpr| {
2461            use HirRelationExpr::*;
2462            use HirScalarExpr::*;
2463
2464            e.visit_children(|scalar: &HirScalarExpr| {
2465                scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2466                    result |= match scalar {
2467                        Column(..)
2468                        | Literal(..)
2469                        | CallUnmaterializable(..)
2470                        | If { .. }
2471                        | Parameter(..)
2472                        | Select(..)
2473                        | Exists(..) => false,
2474                        // Function calls are considered expensive
2475                        CallUnary { .. }
2476                        | CallBinary { .. }
2477                        | CallVariadic { .. }
2478                        | Windowing(..) => true,
2479                    };
2480                })
2481            });
2482
2483            // CallTable has a table function; Reduce has an aggregate function.
2484            // Other constructs use MirScalarExpr to run a function
2485            result |= matches!(e, CallTable { .. } | Reduce { .. });
2486        });
2487
2488        result
2489    }
2490
2491    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2492    pub fn contains_temporal(&self) -> bool {
2493        let mut contains = false;
2494        self.visit_post(&mut |expr| {
2495            expr.visit_children(|expr: &HirScalarExpr| {
2496                contains = contains || expr.contains_temporal()
2497            })
2498        });
2499        contains
2500    }
2501
2502    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2503    pub fn contains_unmaterializable(&self) -> bool {
2504        let mut contains = false;
2505        self.visit_post(&mut |expr| {
2506            expr.visit_children(|expr: &HirScalarExpr| {
2507                contains = contains || expr.contains_unmaterializable()
2508            })
2509        });
2510        contains
2511    }
2512
2513    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
2514    /// [`UnmaterializableFunc::MzNow`].
2515    pub fn contains_unmaterializable_except_temporal(&self) -> bool {
2516        let mut contains = false;
2517        self.visit_post(&mut |expr| {
2518            expr.visit_children(|expr: &HirScalarExpr| {
2519                contains = contains || expr.contains_unmaterializable_except_temporal()
2520            })
2521        });
2522        contains
2523    }
2524}
2525
2526impl CollectionPlan for HirRelationExpr {
2527    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2528    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2529    /// (It does explore inside subqueries.)
2530    ///
2531    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2532    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2533    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2534        if let Self::Get {
2535            id: Id::Global(id), ..
2536        } = self
2537        {
2538            out.insert(*id);
2539        }
2540        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2541    }
2542}
2543
2544/// In addition to direct relation children of `self`, this also yields every
2545/// subquery (`Exists` / `Select`) reachable through `self`'s scalar children,
2546/// at any depth of scalar nesting. This is the asymmetry warned about on the
2547/// [`VisitChildren`] trait; the matching impl on `HirScalarExpr` deliberately
2548/// does not do the symmetric thing, to avoid mutual recursion.
2549impl VisitChildren<Self> for HirRelationExpr {
2550    fn visit_children<F>(&self, mut f: F)
2551    where
2552        F: FnMut(&Self),
2553    {
2554        // subqueries of type HirRelationExpr might be wrapped in
2555        // Exists or Select variants within HirScalarExpr trees
2556        // attached at the current node, and we want to visit them as well
2557        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2558            expr.visit_direct_subqueries(&mut f);
2559        });
2560
2561        use HirRelationExpr::*;
2562        match self {
2563            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2564            Let {
2565                name: _,
2566                id: _,
2567                value,
2568                body,
2569            } => {
2570                f(value);
2571                f(body);
2572            }
2573            LetRec {
2574                limit: _,
2575                bindings,
2576                body,
2577            } => {
2578                for (_, _, value, _) in bindings.iter() {
2579                    f(value);
2580                }
2581                f(body);
2582            }
2583            Project { input, outputs: _ } => f(input),
2584            Map { input, scalars: _ } => {
2585                f(input);
2586            }
2587            CallTable { func: _, exprs: _ } => (),
2588            Filter {
2589                input,
2590                predicates: _,
2591            } => {
2592                f(input);
2593            }
2594            Join {
2595                left,
2596                right,
2597                on: _,
2598                kind: _,
2599            } => {
2600                f(left);
2601                f(right);
2602            }
2603            Reduce {
2604                input,
2605                group_key: _,
2606                aggregates: _,
2607                expected_group_size: _,
2608            } => {
2609                f(input);
2610            }
2611            Distinct { input }
2612            | TopK {
2613                input,
2614                group_key: _,
2615                order_key: _,
2616                limit: _,
2617                offset: _,
2618                expected_group_size: _,
2619            }
2620            | Negate { input }
2621            | Threshold { input } => {
2622                f(input);
2623            }
2624            Union { base, inputs } => {
2625                f(base);
2626                for input in inputs {
2627                    f(input);
2628                }
2629            }
2630        }
2631    }
2632
2633    fn visit_mut_children<F>(&mut self, mut f: F)
2634    where
2635        F: FnMut(&mut Self),
2636    {
2637        // subqueries of type HirRelationExpr might be wrapped in
2638        // Exists or Select variants within HirScalarExpr trees
2639        // attached at the current node, and we want to visit them as well
2640        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2641            expr.visit_direct_subqueries_mut(&mut f);
2642        });
2643
2644        use HirRelationExpr::*;
2645        match self {
2646            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2647            Let {
2648                name: _,
2649                id: _,
2650                value,
2651                body,
2652            } => {
2653                f(value);
2654                f(body);
2655            }
2656            LetRec {
2657                limit: _,
2658                bindings,
2659                body,
2660            } => {
2661                for (_, _, value, _) in bindings.iter_mut() {
2662                    f(value);
2663                }
2664                f(body);
2665            }
2666            Project { input, outputs: _ } => f(input),
2667            Map { input, scalars: _ } => {
2668                f(input);
2669            }
2670            CallTable { func: _, exprs: _ } => (),
2671            Filter {
2672                input,
2673                predicates: _,
2674            } => {
2675                f(input);
2676            }
2677            Join {
2678                left,
2679                right,
2680                on: _,
2681                kind: _,
2682            } => {
2683                f(left);
2684                f(right);
2685            }
2686            Reduce {
2687                input,
2688                group_key: _,
2689                aggregates: _,
2690                expected_group_size: _,
2691            } => {
2692                f(input);
2693            }
2694            Distinct { input }
2695            | TopK {
2696                input,
2697                group_key: _,
2698                order_key: _,
2699                limit: _,
2700                offset: _,
2701                expected_group_size: _,
2702            }
2703            | Negate { input }
2704            | Threshold { input } => {
2705                f(input);
2706            }
2707            Union { base, inputs } => {
2708                f(base);
2709                for input in inputs {
2710                    f(input);
2711                }
2712            }
2713        }
2714    }
2715
2716    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2717    where
2718        F: FnMut(&Self) -> Result<(), E>,
2719    {
2720        // subqueries of type HirRelationExpr might be wrapped in
2721        // Exists or Select variants within HirScalarExpr trees
2722        // attached at the current node, and we want to visit them as well
2723        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2724            expr.try_visit_direct_subqueries(&mut f)
2725        })?;
2726
2727        use HirRelationExpr::*;
2728        match self {
2729            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2730            Let {
2731                name: _,
2732                id: _,
2733                value,
2734                body,
2735            } => {
2736                f(value)?;
2737                f(body)?;
2738            }
2739            LetRec {
2740                limit: _,
2741                bindings,
2742                body,
2743            } => {
2744                for (_, _, value, _) in bindings.iter() {
2745                    f(value)?;
2746                }
2747                f(body)?;
2748            }
2749            Project { input, outputs: _ } => f(input)?,
2750            Map { input, scalars: _ } => {
2751                f(input)?;
2752            }
2753            CallTable { func: _, exprs: _ } => (),
2754            Filter {
2755                input,
2756                predicates: _,
2757            } => {
2758                f(input)?;
2759            }
2760            Join {
2761                left,
2762                right,
2763                on: _,
2764                kind: _,
2765            } => {
2766                f(left)?;
2767                f(right)?;
2768            }
2769            Reduce {
2770                input,
2771                group_key: _,
2772                aggregates: _,
2773                expected_group_size: _,
2774            } => {
2775                f(input)?;
2776            }
2777            Distinct { input }
2778            | TopK {
2779                input,
2780                group_key: _,
2781                order_key: _,
2782                limit: _,
2783                offset: _,
2784                expected_group_size: _,
2785            }
2786            | Negate { input }
2787            | Threshold { input } => {
2788                f(input)?;
2789            }
2790            Union { base, inputs } => {
2791                f(base)?;
2792                for input in inputs {
2793                    f(input)?;
2794                }
2795            }
2796        }
2797        Ok(())
2798    }
2799
2800    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2801    where
2802        F: FnMut(&mut Self) -> Result<(), E>,
2803    {
2804        // subqueries of type HirRelationExpr might be wrapped in
2805        // Exists or Select variants within HirScalarExpr trees
2806        // attached at the current node, and we want to visit them as well
2807        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2808            expr.try_visit_direct_subqueries_mut(&mut f)
2809        })?;
2810
2811        use HirRelationExpr::*;
2812        match self {
2813            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2814            Let {
2815                name: _,
2816                id: _,
2817                value,
2818                body,
2819            } => {
2820                f(value)?;
2821                f(body)?;
2822            }
2823            LetRec {
2824                limit: _,
2825                bindings,
2826                body,
2827            } => {
2828                for (_, _, value, _) in bindings.iter_mut() {
2829                    f(value)?;
2830                }
2831                f(body)?;
2832            }
2833            Project { input, outputs: _ } => f(input)?,
2834            Map { input, scalars: _ } => {
2835                f(input)?;
2836            }
2837            CallTable { func: _, exprs: _ } => (),
2838            Filter {
2839                input,
2840                predicates: _,
2841            } => {
2842                f(input)?;
2843            }
2844            Join {
2845                left,
2846                right,
2847                on: _,
2848                kind: _,
2849            } => {
2850                f(left)?;
2851                f(right)?;
2852            }
2853            Reduce {
2854                input,
2855                group_key: _,
2856                aggregates: _,
2857                expected_group_size: _,
2858            } => {
2859                f(input)?;
2860            }
2861            Distinct { input }
2862            | TopK {
2863                input,
2864                group_key: _,
2865                order_key: _,
2866                limit: _,
2867                offset: _,
2868                expected_group_size: _,
2869            }
2870            | Negate { input }
2871            | Threshold { input } => {
2872                f(input)?;
2873            }
2874            Union { base, inputs } => {
2875                f(base)?;
2876                for input in inputs {
2877                    f(input)?;
2878                }
2879            }
2880        }
2881        Ok(())
2882    }
2883
2884    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a Self>
2885    where
2886        Self: 'a,
2887    {
2888        // we visit subqueries _first_, then the input
2889        let mut v: Vec<&HirRelationExpr> = vec![];
2890        use HirRelationExpr::*;
2891        match self {
2892            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2893            Let {
2894                name: _,
2895                id: _,
2896                value,
2897                body,
2898            } => {
2899                v.push(&*value);
2900                v.push(&*body);
2901            }
2902            LetRec {
2903                limit: _,
2904                bindings,
2905                body,
2906            } => {
2907                v.extend(bindings.iter().map(|(_, _, value, _)| value));
2908                v.push(&*body);
2909            }
2910            Map { input, scalars }
2911            | Filter {
2912                input,
2913                predicates: scalars,
2914            } => {
2915                for scalar in scalars {
2916                    v.append(&mut scalar.direct_subqueries());
2917                }
2918                v.push(&*input);
2919            }
2920            Reduce {
2921                input,
2922                group_key: _,
2923                aggregates,
2924                expected_group_size: _,
2925            } => {
2926                for agg in aggregates {
2927                    v.append(&mut agg.expr.direct_subqueries());
2928                }
2929                v.push(&*input);
2930            }
2931            TopK {
2932                input,
2933                group_key: _,
2934                order_key: _,
2935                limit,
2936                offset,
2937                expected_group_size: _,
2938            } => {
2939                if let Some(limit) = limit {
2940                    v.append(&mut limit.direct_subqueries());
2941                }
2942                v.append(&mut offset.direct_subqueries());
2943                v.push(&*input);
2944            }
2945            Project { input, outputs: _ }
2946            | Distinct { input }
2947            | Negate { input }
2948            | Threshold { input } => v.push(&*input),
2949            CallTable { func: _, exprs } => v.extend(
2950                exprs
2951                    .iter()
2952                    .map(|scalar| scalar.direct_subqueries())
2953                    .flatten(),
2954            ),
2955            Join {
2956                left,
2957                right,
2958                on,
2959                kind: _,
2960            } => {
2961                v.append(&mut on.direct_subqueries());
2962                v.push(&*left);
2963                v.push(&*right);
2964            }
2965            Union { base, inputs } => {
2966                v.push(&*base);
2967                v.extend(inputs.iter());
2968            }
2969        }
2970
2971        v.into_iter()
2972    }
2973
2974    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut Self>
2975    where
2976        Self: 'a,
2977    {
2978        // we visit subqueries _first_, then the input
2979        let mut v = vec![];
2980        use HirRelationExpr::*;
2981        match self {
2982            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2983            Let {
2984                name: _,
2985                id: _,
2986                value,
2987                body,
2988            } => {
2989                v.push(&mut **value);
2990                v.push(&mut **body);
2991            }
2992            LetRec {
2993                limit: _,
2994                bindings,
2995                body,
2996            } => {
2997                v.extend(bindings.iter_mut().map(|(_, _, value, _)| value));
2998                v.push(&mut **body);
2999            }
3000            Map { input, scalars }
3001            | Filter {
3002                input,
3003                predicates: scalars,
3004            } => {
3005                for scalar in scalars {
3006                    v.append(&mut scalar.direct_subqueries_mut());
3007                }
3008                v.push(&mut **input);
3009            }
3010            Reduce {
3011                input,
3012                group_key: _,
3013                aggregates,
3014                expected_group_size: _,
3015            } => {
3016                for agg in aggregates {
3017                    v.append(&mut agg.expr.direct_subqueries_mut());
3018                }
3019                v.push(&mut **input);
3020            }
3021            TopK {
3022                input,
3023                group_key: _,
3024                order_key: _,
3025                limit,
3026                offset,
3027                expected_group_size: _,
3028            } => {
3029                if let Some(limit) = limit {
3030                    v.append(&mut limit.direct_subqueries_mut());
3031                }
3032                v.append(&mut offset.direct_subqueries_mut());
3033                v.push(&mut **input);
3034            }
3035            Project { input, outputs: _ }
3036            | Distinct { input }
3037            | Negate { input }
3038            | Threshold { input } => v.push(&mut **input),
3039            CallTable { func: _, exprs } => v.extend(
3040                exprs
3041                    .iter_mut()
3042                    .map(|scalar| scalar.direct_subqueries_mut())
3043                    .flatten(),
3044            ),
3045            Join {
3046                left,
3047                right,
3048                on,
3049                kind: _,
3050            } => {
3051                v.append(&mut on.direct_subqueries_mut());
3052                v.push(&mut **left);
3053                v.push(&mut **right);
3054            }
3055            Union { base, inputs } => {
3056                v.push(&mut **base);
3057                v.extend(inputs.iter_mut());
3058            }
3059        }
3060
3061        v.into_iter()
3062    }
3063}
3064
3065/// Yields the scalars directly attached to relation nodes (e.g. `Map.scalars`,
3066/// `Filter.predicates`, `Join.on`, `Reduce` aggregate args, `TopK.{limit,
3067/// offset}`, `CallTable.exprs`); does not descend into them.
3068impl VisitChildren<HirScalarExpr> for HirRelationExpr {
3069    fn visit_children<F>(&self, mut f: F)
3070    where
3071        F: FnMut(&HirScalarExpr),
3072    {
3073        use HirRelationExpr::*;
3074        match self {
3075            Constant { rows: _, typ: _ }
3076            | Get { id: _, typ: _ }
3077            | Let {
3078                name: _,
3079                id: _,
3080                value: _,
3081                body: _,
3082            }
3083            | LetRec {
3084                limit: _,
3085                bindings: _,
3086                body: _,
3087            }
3088            | Project {
3089                input: _,
3090                outputs: _,
3091            } => (),
3092            Map { input: _, scalars } => {
3093                for scalar in scalars {
3094                    f(scalar);
3095                }
3096            }
3097            CallTable { func: _, exprs } => {
3098                for expr in exprs {
3099                    f(expr);
3100                }
3101            }
3102            Filter {
3103                input: _,
3104                predicates,
3105            } => {
3106                for predicate in predicates {
3107                    f(predicate);
3108                }
3109            }
3110            Join {
3111                left: _,
3112                right: _,
3113                on,
3114                kind: _,
3115            } => f(on),
3116            Reduce {
3117                input: _,
3118                group_key: _,
3119                aggregates,
3120                expected_group_size: _,
3121            } => {
3122                for aggregate in aggregates {
3123                    f(aggregate.expr.as_ref());
3124                }
3125            }
3126            TopK {
3127                input: _,
3128                group_key: _,
3129                order_key: _,
3130                limit,
3131                offset,
3132                expected_group_size: _,
3133            } => {
3134                if let Some(limit) = limit {
3135                    f(limit)
3136                }
3137                f(offset)
3138            }
3139            Distinct { input: _ }
3140            | Negate { input: _ }
3141            | Threshold { input: _ }
3142            | Union { base: _, inputs: _ } => (),
3143        }
3144    }
3145
3146    fn visit_mut_children<F>(&mut self, mut f: F)
3147    where
3148        F: FnMut(&mut HirScalarExpr),
3149    {
3150        use HirRelationExpr::*;
3151        match self {
3152            Constant { rows: _, typ: _ }
3153            | Get { id: _, typ: _ }
3154            | Let {
3155                name: _,
3156                id: _,
3157                value: _,
3158                body: _,
3159            }
3160            | LetRec {
3161                limit: _,
3162                bindings: _,
3163                body: _,
3164            }
3165            | Project {
3166                input: _,
3167                outputs: _,
3168            } => (),
3169            Map { input: _, scalars } => {
3170                for scalar in scalars {
3171                    f(scalar);
3172                }
3173            }
3174            CallTable { func: _, exprs } => {
3175                for expr in exprs {
3176                    f(expr);
3177                }
3178            }
3179            Filter {
3180                input: _,
3181                predicates,
3182            } => {
3183                for predicate in predicates {
3184                    f(predicate);
3185                }
3186            }
3187            Join {
3188                left: _,
3189                right: _,
3190                on,
3191                kind: _,
3192            } => f(on),
3193            Reduce {
3194                input: _,
3195                group_key: _,
3196                aggregates,
3197                expected_group_size: _,
3198            } => {
3199                for aggregate in aggregates {
3200                    f(aggregate.expr.as_mut());
3201                }
3202            }
3203            TopK {
3204                input: _,
3205                group_key: _,
3206                order_key: _,
3207                limit,
3208                offset,
3209                expected_group_size: _,
3210            } => {
3211                if let Some(limit) = limit {
3212                    f(limit)
3213                }
3214                f(offset)
3215            }
3216            Distinct { input: _ }
3217            | Negate { input: _ }
3218            | Threshold { input: _ }
3219            | Union { base: _, inputs: _ } => (),
3220        }
3221    }
3222
3223    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3224    where
3225        F: FnMut(&HirScalarExpr) -> Result<(), E>,
3226    {
3227        use HirRelationExpr::*;
3228        match self {
3229            Constant { rows: _, typ: _ }
3230            | Get { id: _, typ: _ }
3231            | Let {
3232                name: _,
3233                id: _,
3234                value: _,
3235                body: _,
3236            }
3237            | LetRec {
3238                limit: _,
3239                bindings: _,
3240                body: _,
3241            }
3242            | Project {
3243                input: _,
3244                outputs: _,
3245            } => (),
3246            Map { input: _, scalars } => {
3247                for scalar in scalars {
3248                    f(scalar)?;
3249                }
3250            }
3251            CallTable { func: _, exprs } => {
3252                for expr in exprs {
3253                    f(expr)?;
3254                }
3255            }
3256            Filter {
3257                input: _,
3258                predicates,
3259            } => {
3260                for predicate in predicates {
3261                    f(predicate)?;
3262                }
3263            }
3264            Join {
3265                left: _,
3266                right: _,
3267                on,
3268                kind: _,
3269            } => f(on)?,
3270            Reduce {
3271                input: _,
3272                group_key: _,
3273                aggregates,
3274                expected_group_size: _,
3275            } => {
3276                for aggregate in aggregates {
3277                    f(aggregate.expr.as_ref())?;
3278                }
3279            }
3280            TopK {
3281                input: _,
3282                group_key: _,
3283                order_key: _,
3284                limit,
3285                offset,
3286                expected_group_size: _,
3287            } => {
3288                if let Some(limit) = limit {
3289                    f(limit)?
3290                }
3291                f(offset)?
3292            }
3293            Distinct { input: _ }
3294            | Negate { input: _ }
3295            | Threshold { input: _ }
3296            | Union { base: _, inputs: _ } => (),
3297        }
3298        Ok(())
3299    }
3300
3301    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3302    where
3303        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3304    {
3305        use HirRelationExpr::*;
3306        match self {
3307            Constant { rows: _, typ: _ }
3308            | Get { id: _, typ: _ }
3309            | Let {
3310                name: _,
3311                id: _,
3312                value: _,
3313                body: _,
3314            }
3315            | LetRec {
3316                limit: _,
3317                bindings: _,
3318                body: _,
3319            }
3320            | Project {
3321                input: _,
3322                outputs: _,
3323            } => (),
3324            Map { input: _, scalars } => {
3325                for scalar in scalars {
3326                    f(scalar)?;
3327                }
3328            }
3329            CallTable { func: _, exprs } => {
3330                for expr in exprs {
3331                    f(expr)?;
3332                }
3333            }
3334            Filter {
3335                input: _,
3336                predicates,
3337            } => {
3338                for predicate in predicates {
3339                    f(predicate)?;
3340                }
3341            }
3342            Join {
3343                left: _,
3344                right: _,
3345                on,
3346                kind: _,
3347            } => f(on)?,
3348            Reduce {
3349                input: _,
3350                group_key: _,
3351                aggregates,
3352                expected_group_size: _,
3353            } => {
3354                for aggregate in aggregates {
3355                    f(aggregate.expr.as_mut())?;
3356                }
3357            }
3358            TopK {
3359                input: _,
3360                group_key: _,
3361                order_key: _,
3362                limit,
3363                offset,
3364                expected_group_size: _,
3365            } => {
3366                if let Some(limit) = limit {
3367                    f(limit)?
3368                }
3369                f(offset)?
3370            }
3371            Distinct { input: _ }
3372            | Negate { input: _ }
3373            | Threshold { input: _ }
3374            | Union { base: _, inputs: _ } => (),
3375        }
3376        Ok(())
3377    }
3378
3379    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirScalarExpr>
3380    where
3381        HirScalarExpr: 'a,
3382    {
3383        use HirRelationExpr::*;
3384        match self {
3385            Constant { rows: _, typ: _ }
3386            | Get { id: _, typ: _ }
3387            | Let {
3388                name: _,
3389                id: _,
3390                value: _,
3391                body: _,
3392            }
3393            | LetRec {
3394                limit: _,
3395                bindings: _,
3396                body: _,
3397            }
3398            | Project {
3399                input: _,
3400                outputs: _,
3401            }
3402            | Distinct { input: _ }
3403            | Negate { input: _ }
3404            | Threshold { input: _ }
3405            | Union { base: _, inputs: _ } => vec![],
3406            Map { input: _, scalars }
3407            | CallTable {
3408                func: _,
3409                exprs: scalars,
3410            }
3411            | Filter {
3412                input: _,
3413                predicates: scalars,
3414            } => scalars.iter().collect(),
3415            Join {
3416                left: _,
3417                right: _,
3418                on,
3419                kind: _,
3420            } => vec![on],
3421            Reduce {
3422                input: _,
3423                group_key: _,
3424                aggregates,
3425                expected_group_size: _,
3426            } => aggregates.iter().map(|agg| &*agg.expr).collect(),
3427            TopK {
3428                input: _,
3429                group_key: _,
3430                order_key: _,
3431                limit,
3432                offset,
3433                expected_group_size: _,
3434            } => limit.iter().chain(std::iter::once(offset)).collect(),
3435        }
3436        .into_iter()
3437    }
3438
3439    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirScalarExpr>
3440    where
3441        HirScalarExpr: 'a,
3442    {
3443        use HirRelationExpr::*;
3444        match self {
3445            Constant { rows: _, typ: _ }
3446            | Get { id: _, typ: _ }
3447            | Let {
3448                name: _,
3449                id: _,
3450                value: _,
3451                body: _,
3452            }
3453            | LetRec {
3454                limit: _,
3455                bindings: _,
3456                body: _,
3457            }
3458            | Project {
3459                input: _,
3460                outputs: _,
3461            }
3462            | Distinct { input: _ }
3463            | Negate { input: _ }
3464            | Threshold { input: _ }
3465            | Union { base: _, inputs: _ } => vec![],
3466            Map { input: _, scalars }
3467            | CallTable {
3468                func: _,
3469                exprs: scalars,
3470            }
3471            | Filter {
3472                input: _,
3473                predicates: scalars,
3474            } => scalars.iter_mut().collect(),
3475            Join {
3476                left: _,
3477                right: _,
3478                on,
3479                kind: _,
3480            } => vec![on],
3481            Reduce {
3482                input: _,
3483                group_key: _,
3484                aggregates,
3485                expected_group_size: _,
3486            } => aggregates.iter_mut().map(|agg| &mut *agg.expr).collect(),
3487            TopK {
3488                input: _,
3489                group_key: _,
3490                order_key: _,
3491                limit,
3492                offset,
3493                expected_group_size: _,
3494            } => limit.iter_mut().chain(std::iter::once(offset)).collect(),
3495        }
3496        .into_iter()
3497    }
3498}
3499
3500impl HirScalarExpr {
3501    pub fn name(&self) -> Option<Arc<str>> {
3502        use HirScalarExpr::*;
3503        match self {
3504            Column(_, name)
3505            | Parameter(_, name)
3506            | Literal(_, _, name)
3507            | CallUnmaterializable(_, name)
3508            | CallUnary { name, .. }
3509            | CallBinary { name, .. }
3510            | CallVariadic { name, .. }
3511            | If { name, .. }
3512            | Exists(_, name)
3513            | Select(_, name)
3514            | Windowing(_, name) => name.0.clone(),
3515        }
3516    }
3517
3518    /// Visit every subquery, without descending into subqueries.
3519    pub fn visit_direct_subqueries<F>(&self, mut f: F)
3520    where
3521        F: FnMut(&HirRelationExpr),
3522    {
3523        self.visit_post(&mut |e| {
3524            VisitChildren::<HirRelationExpr>::visit_children(e, &mut f);
3525        });
3526    }
3527
3528    /// Mutable counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3529    pub fn visit_direct_subqueries_mut<F>(&mut self, mut f: F)
3530    where
3531        F: FnMut(&mut HirRelationExpr),
3532    {
3533        self.visit_mut_post(&mut |e| {
3534            VisitChildren::<HirRelationExpr>::visit_mut_children(e, &mut f);
3535        });
3536    }
3537
3538    /// Fallible counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3539    pub fn try_visit_direct_subqueries<F, E>(&self, mut f: F) -> Result<(), E>
3540    where
3541        F: FnMut(&HirRelationExpr) -> Result<(), E>,
3542    {
3543        self.try_visit_post(&mut |e| {
3544            VisitChildren::<HirRelationExpr>::try_visit_children(e, &mut f)
3545        })
3546    }
3547
3548    /// Fallible mutable counterpart of [`HirScalarExpr::visit_direct_subqueries`].
3549    pub fn try_visit_direct_subqueries_mut<F, E>(&mut self, mut f: F) -> Result<(), E>
3550    where
3551        F: FnMut(&mut HirRelationExpr) -> Result<(), E>,
3552    {
3553        self.try_visit_mut_post(&mut |e| {
3554            VisitChildren::<HirRelationExpr>::try_visit_mut_children(e, &mut f)
3555        })
3556    }
3557
3558    /// Replaces parameter references in the expression with the corresponding datum from `params`.
3559    /// Additionally, it simplifies OFFSET clauses to constants after parameter binding, and checks
3560    /// them for non-negativity.
3561    ///
3562    /// This handles subqueries nested inside `self` by calling back into
3563    /// [`HirRelationExpr::bind_parameters_and_simplify_offset`] on each direct subquery; that
3564    /// relation-level function in turn calls back here for the scalar expressions it holds, so
3565    /// the two functions corecursively cover the entire HIR tree.
3566    pub fn bind_parameters_and_simplify_offset(
3567        &mut self,
3568        scx: &StatementContext,
3569        lifetime: QueryLifetime,
3570        params: &Params,
3571    ) -> Result<(), PlanError> {
3572        // First, rewrite each `Parameter` node to a literal. This walks only the scalar tree;
3573        // it does not yet descend into subqueries.
3574        self.try_visit_mut_post(&mut |e: &mut HirScalarExpr| {
3575            if let HirScalarExpr::Parameter(n, name) = e {
3576                let datum = match params.datums.iter().nth(*n - 1) {
3577                    None => return Err(PlanError::UnknownParameter(*n)),
3578                    Some(datum) => datum,
3579                };
3580                let scalar_type = &params.execute_types[*n - 1];
3581                let row = Row::pack([datum]);
3582                let column_type = scalar_type.clone().nullable(datum.is_null());
3583
3584                let name = if let Some(name) = &name.0 {
3585                    Some(Arc::clone(name))
3586                } else {
3587                    Some(Arc::from(format!("${n}")))
3588                };
3589
3590                let qcx = QueryContext::root(scx, lifetime);
3591                let ecx = execute_expr_context(&qcx);
3592
3593                *e = plan_cast(
3594                    &ecx,
3595                    *EXECUTE_CAST_CONTEXT,
3596                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3597                    &params.expected_types[*n - 1],
3598                )
3599                .expect("checked in plan_params");
3600            }
3601            Ok(())
3602        })?;
3603        // Then descend into any subqueries; the relation-side `bind_parameters_and_simplify_offset`
3604        // handles corecursion back into scalars.
3605        self.try_visit_direct_subqueries_mut(|r: &mut HirRelationExpr| {
3606            r.bind_parameters_and_simplify_offset(scx, lifetime, params)
3607        })
3608    }
3609
3610    /// Like [`HirScalarExpr::bind_parameters_and_simplify_offset`], except that parameters are
3611    /// replaced with the corresponding expression fragment from `params` rather
3612    /// than a datum.
3613    ///
3614    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3615    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3616    /// in `self` that refer to invalid indices of `params` will cause a panic.
3617    ///
3618    /// Column references in parameters will be corrected to account for the
3619    /// depth at which they are spliced.
3620    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3621        #[allow(deprecated)]
3622        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3623                                                        e: &mut HirScalarExpr|
3624         -> Result<(), ()> {
3625            if let HirScalarExpr::Parameter(i, _name) = e {
3626                *e = params[*i - 1].clone();
3627                // Correct any column references in the parameter expression for
3628                // its new depth.
3629                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3630                    if col.level >= d {
3631                        col.level += depth
3632                    }
3633                });
3634            }
3635            Ok(())
3636        });
3637    }
3638
3639    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3640    pub fn contains_temporal(&self) -> bool {
3641        let mut contains = false;
3642        self.visit_post(&mut |e| {
3643            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3644                contains = true;
3645            }
3646        });
3647        contains
3648    }
3649
3650    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3651    pub fn contains_unmaterializable(&self) -> bool {
3652        let mut contains = false;
3653        self.visit_post(&mut |e| {
3654            if let Self::CallUnmaterializable(_, _) = e {
3655                contains = true;
3656            }
3657        });
3658        contains
3659    }
3660
3661    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
3662    /// [`UnmaterializableFunc::MzNow`].
3663    pub fn contains_unmaterializable_except_temporal(&self) -> bool {
3664        let mut contains = false;
3665        self.visit_post(&mut |e| {
3666            if let Self::CallUnmaterializable(f, _) = e {
3667                if *f != UnmaterializableFunc::MzNow {
3668                    contains = true;
3669                }
3670            }
3671        });
3672        contains
3673    }
3674
3675    /// Constructs an unnamed column reference in the current scope.
3676    /// Use [`HirScalarExpr::named_column`] when a name is known.
3677    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3678    pub fn column(index: usize) -> HirScalarExpr {
3679        HirScalarExpr::Column(
3680            ColumnRef {
3681                level: 0,
3682                column: index,
3683            },
3684            TreatAsEqual(None),
3685        )
3686    }
3687
3688    /// Constructs an unnamed column reference.
3689    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3690        HirScalarExpr::Column(cr, TreatAsEqual(None))
3691    }
3692
3693    /// Constructs a named column reference.
3694    /// Names are interned by a `NameManager`.
3695    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3696        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3697    }
3698
3699    pub fn parameter(n: usize) -> HirScalarExpr {
3700        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3701    }
3702
3703    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3704        let col_type = scalar_type.nullable(datum.is_null());
3705        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3706        let row = Row::pack([datum]);
3707        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3708    }
3709
3710    pub fn literal_true() -> HirScalarExpr {
3711        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3712    }
3713
3714    pub fn literal_false() -> HirScalarExpr {
3715        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3716    }
3717
3718    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3719        HirScalarExpr::literal(Datum::Null, scalar_type)
3720    }
3721
3722    pub fn literal_1d_array(
3723        datums: Vec<Datum>,
3724        element_scalar_type: SqlScalarType,
3725    ) -> Result<HirScalarExpr, PlanError> {
3726        let scalar_type = match element_scalar_type {
3727            SqlScalarType::Array(_) => {
3728                sql_bail!("cannot build array from array type");
3729            }
3730            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3731        };
3732
3733        let mut row = Row::default();
3734        row.packer()
3735            .try_push_array(
3736                &[ArrayDimension {
3737                    lower_bound: 1,
3738                    length: datums.len(),
3739                }],
3740                datums,
3741            )
3742            .expect("array constructed to be valid");
3743
3744        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3745    }
3746
3747    pub fn as_literal(&self) -> Option<Datum<'_>> {
3748        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3749            Some(row.unpack_first())
3750        } else {
3751            None
3752        }
3753    }
3754
3755    pub fn is_literal_true(&self) -> bool {
3756        Some(Datum::True) == self.as_literal()
3757    }
3758
3759    pub fn is_literal_false(&self) -> bool {
3760        Some(Datum::False) == self.as_literal()
3761    }
3762
3763    pub fn is_literal_null(&self) -> bool {
3764        Some(Datum::Null) == self.as_literal()
3765    }
3766
3767    /// Return true iff `self` consists only of literals, materializable function calls, and
3768    /// if-else statements.
3769    pub fn is_constant(&self) -> bool {
3770        let mut worklist = vec![self];
3771        while let Some(expr) = worklist.pop() {
3772            match expr {
3773                Self::Literal(..) => {
3774                    // leaf node, do nothing
3775                }
3776                Self::CallUnary { expr, .. } => {
3777                    worklist.push(expr);
3778                }
3779                Self::CallBinary {
3780                    func: _,
3781                    expr1,
3782                    expr2,
3783                    name: _,
3784                } => {
3785                    worklist.push(expr1);
3786                    worklist.push(expr2);
3787                }
3788                Self::CallVariadic {
3789                    func: _,
3790                    exprs,
3791                    name: _,
3792                } => {
3793                    worklist.extend(exprs.iter());
3794                }
3795                // (CallUnmaterializable is not allowed)
3796                Self::If {
3797                    cond,
3798                    then,
3799                    els,
3800                    name: _,
3801                } => {
3802                    worklist.push(cond);
3803                    worklist.push(then);
3804                    worklist.push(els);
3805                }
3806                _ => {
3807                    return false; // Any other node makes `self` non-constant.
3808                }
3809            }
3810        }
3811        true
3812    }
3813
3814    pub fn call_unary(self, func: UnaryFunc) -> Self {
3815        HirScalarExpr::CallUnary {
3816            func,
3817            expr: Box::new(self),
3818            name: NameMetadata::default(),
3819        }
3820    }
3821
3822    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3823        HirScalarExpr::CallBinary {
3824            func: func.into(),
3825            expr1: Box::new(self),
3826            expr2: Box::new(other),
3827            name: NameMetadata::default(),
3828        }
3829    }
3830
3831    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3832        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3833    }
3834
3835    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
3836        HirScalarExpr::CallVariadic {
3837            func: func.into(),
3838            exprs,
3839            name: NameMetadata::default(),
3840        }
3841    }
3842
3843    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3844        HirScalarExpr::If {
3845            cond: Box::new(cond),
3846            then: Box::new(then),
3847            els: Box::new(els),
3848            name: NameMetadata::default(),
3849        }
3850    }
3851
3852    pub fn windowing(expr: WindowExpr) -> Self {
3853        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3854    }
3855
3856    pub fn or(self, other: Self) -> Self {
3857        HirScalarExpr::call_variadic(Or, vec![self, other])
3858    }
3859
3860    pub fn and(self, other: Self) -> Self {
3861        HirScalarExpr::call_variadic(And, vec![self, other])
3862    }
3863
3864    pub fn not(self) -> Self {
3865        self.call_unary(UnaryFunc::Not(func::Not))
3866    }
3867
3868    pub fn call_is_null(self) -> Self {
3869        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3870    }
3871
3872    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3873    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3874        match args.len() {
3875            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3876            1 => args.swap_remove(0),
3877            _ => HirScalarExpr::call_variadic(And, args),
3878        }
3879    }
3880
3881    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3882    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3883        match args.len() {
3884            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3885            1 => args.swap_remove(0),
3886            _ => HirScalarExpr::call_variadic(Or, args),
3887        }
3888    }
3889
3890    pub fn take(&mut self) -> Self {
3891        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3892    }
3893
3894    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3895    /// Visits the column references in this scalar expression.
3896    ///
3897    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3898    /// which will be incremented with each subquery entered and presented to the supplied
3899    /// function `f`.
3900    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3901    where
3902        F: FnMut(usize, &ColumnRef),
3903    {
3904        #[allow(deprecated)]
3905        let _ = self.visit_recursively(depth, &mut |depth: usize,
3906                                                    e: &HirScalarExpr|
3907         -> Result<(), ()> {
3908            if let HirScalarExpr::Column(col, _name) = e {
3909                f(depth, col)
3910            }
3911            Ok(())
3912        });
3913    }
3914
3915    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3916    /// Like `visit_columns`, but permits mutating the column references.
3917    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3918    where
3919        F: FnMut(usize, &mut ColumnRef),
3920    {
3921        #[allow(deprecated)]
3922        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3923                                                        e: &mut HirScalarExpr|
3924         -> Result<(), ()> {
3925            if let HirScalarExpr::Column(col, _name) = e {
3926                f(depth, col)
3927            }
3928            Ok(())
3929        });
3930    }
3931
3932    /// Visits those column references in this scalar expression that refer to the root
3933    /// level. These include column references that are at the root level, as well as column
3934    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3935    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3936    /// "root level" to be `self`'s level.)
3937    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3938    where
3939        F: FnMut(usize),
3940    {
3941        #[allow(deprecated)]
3942        let _ = self.visit_recursively(0, &mut |depth: usize,
3943                                                e: &HirScalarExpr|
3944         -> Result<(), ()> {
3945            if let HirScalarExpr::Column(col, _name) = e {
3946                if col.level == depth {
3947                    f(col.column)
3948                }
3949            }
3950            Ok(())
3951        });
3952    }
3953
3954    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3955    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3956    where
3957        F: FnMut(&mut usize),
3958    {
3959        #[allow(deprecated)]
3960        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3961                                                    e: &mut HirScalarExpr|
3962         -> Result<(), ()> {
3963            if let HirScalarExpr::Column(col, _name) = e {
3964                if col.level == depth {
3965                    f(&mut col.column)
3966                }
3967            }
3968            Ok(())
3969        });
3970    }
3971
3972    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3973    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3974    /// in them. It takes the current depth of the expression and increases it when
3975    /// entering a subquery.
3976    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3977    where
3978        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3979    {
3980        match self {
3981            HirScalarExpr::Literal(..)
3982            | HirScalarExpr::Parameter(..)
3983            | HirScalarExpr::CallUnmaterializable(..)
3984            | HirScalarExpr::Column(..) => (),
3985            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3986            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3987                expr1.visit_recursively(depth, f)?;
3988                expr2.visit_recursively(depth, f)?;
3989            }
3990            HirScalarExpr::CallVariadic { exprs, .. } => {
3991                for expr in exprs {
3992                    expr.visit_recursively(depth, f)?;
3993                }
3994            }
3995            HirScalarExpr::If {
3996                cond,
3997                then,
3998                els,
3999                name: _,
4000            } => {
4001                cond.visit_recursively(depth, f)?;
4002                then.visit_recursively(depth, f)?;
4003                els.visit_recursively(depth, f)?;
4004            }
4005            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
4006                #[allow(deprecated)]
4007                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
4008                    e.visit_recursively(depth, f)
4009                })?;
4010            }
4011            HirScalarExpr::Windowing(expr, _name) => {
4012                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
4013            }
4014        }
4015        f(depth, self)
4016    }
4017
4018    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
4019    /// Like `visit_recursively`, but permits mutating the scalar expressions.
4020    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
4021    where
4022        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
4023    {
4024        match self {
4025            HirScalarExpr::Literal(..)
4026            | HirScalarExpr::Parameter(..)
4027            | HirScalarExpr::CallUnmaterializable(..)
4028            | HirScalarExpr::Column(..) => (),
4029            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
4030            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
4031                expr1.visit_recursively_mut(depth, f)?;
4032                expr2.visit_recursively_mut(depth, f)?;
4033            }
4034            HirScalarExpr::CallVariadic { exprs, .. } => {
4035                for expr in exprs {
4036                    expr.visit_recursively_mut(depth, f)?;
4037                }
4038            }
4039            HirScalarExpr::If {
4040                cond,
4041                then,
4042                els,
4043                name: _,
4044            } => {
4045                cond.visit_recursively_mut(depth, f)?;
4046                then.visit_recursively_mut(depth, f)?;
4047                els.visit_recursively_mut(depth, f)?;
4048            }
4049            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
4050                #[allow(deprecated)]
4051                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
4052                    e.visit_recursively_mut(depth, f)
4053                })?;
4054            }
4055            HirScalarExpr::Windowing(expr, _name) => {
4056                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
4057            }
4058        }
4059        f(depth, self)
4060    }
4061
4062    /// Attempts to simplify self into a literal.
4063    ///
4064    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
4065    /// an evaluation error occurs during simplification, or if self contains
4066    /// - a subquery
4067    /// - a column reference to an outer level
4068    /// - a parameter
4069    /// - a window function call
4070    fn simplify_to_literal(self) -> Option<Row> {
4071        let mut expr = self
4072            .lower_uncorrelated(crate::plan::lowering::Config::default())
4073            .ok()?;
4074        // Using MIR evaluation with repr types is fine here: the
4075        // result is an untyped Row, so any intermediate type
4076        // canonicalization is discarded.
4077        expr.reduce(&[]);
4078        match expr {
4079            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
4080            _ => None,
4081        }
4082    }
4083
4084    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
4085    /// or an evaluation error occurs during simplification), it returns
4086    /// [`PlanError::ConstantExpressionSimplificationFailed`].
4087    ///
4088    /// The returned error is an _internal_ error if the expression contains
4089    /// - a subquery
4090    /// - a column reference to an outer level
4091    /// - a parameter
4092    /// - a window function call
4093    ///
4094    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
4095    /// msg.
4096    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
4097        let mut expr = self
4098            .lower_uncorrelated(crate::plan::lowering::Config::default())
4099            .map_err(|err| {
4100                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
4101            })?;
4102        // Using MIR evaluation with repr types is fine here: the
4103        // result is an untyped Row, so any intermediate type
4104        // canonicalization is discarded.
4105        expr.reduce(&[]);
4106        match expr {
4107            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
4108            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
4109                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
4110            ),
4111            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
4112                "Not a constant".to_string(),
4113            )),
4114        }
4115    }
4116
4117    /// Attempts to simplify this expression to a literal 64-bit integer.
4118    ///
4119    /// Returns `None` if this expression cannot be simplified, e.g. because it
4120    /// contains non-literal values.
4121    ///
4122    /// # Panics
4123    ///
4124    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
4125    pub fn into_literal_int64(self) -> Option<i64> {
4126        self.simplify_to_literal().and_then(|row| {
4127            let datum = row.unpack_first();
4128            if datum.is_null() {
4129                None
4130            } else {
4131                Some(datum.unwrap_int64())
4132            }
4133        })
4134    }
4135
4136    /// Attempts to simplify this expression to a literal string.
4137    ///
4138    /// Returns `None` if this expression cannot be simplified, e.g. because it
4139    /// contains non-literal values.
4140    ///
4141    /// # Panics
4142    ///
4143    /// Panics if this expression does not have type [`SqlScalarType::String`].
4144    pub fn into_literal_string(self) -> Option<String> {
4145        self.simplify_to_literal().and_then(|row| {
4146            let datum = row.unpack_first();
4147            if datum.is_null() {
4148                None
4149            } else {
4150                Some(datum.unwrap_str().to_owned())
4151            }
4152        })
4153    }
4154
4155    /// Attempts to simplify this expression to a literal MzTimestamp.
4156    ///
4157    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
4158    /// simplified, e.g. because it contains non-literal values or a cast fails.
4159    ///
4160    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
4161    /// error when it fails. (E.g., there can be non-trivial cast errors.)
4162    /// See `try_into_literal_int64` as an example.
4163    ///
4164    /// # Panics
4165    ///
4166    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
4167    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
4168        self.simplify_to_literal().and_then(|row| {
4169            let datum = row.unpack_first();
4170            if datum.is_null() {
4171                None
4172            } else {
4173                Some(datum.unwrap_mz_timestamp())
4174            }
4175        })
4176    }
4177
4178    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
4179    /// returns it as an i64.
4180    ///
4181    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
4182    /// - it's not a constant expression (as determined by `is_constant`)
4183    /// - evaluates to null
4184    /// - an EvalError occurs during evaluation (e.g., a cast fails)
4185    ///
4186    /// # Panics
4187    ///
4188    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
4189    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
4190        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
4191        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
4192        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
4193        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
4194        // preconditions also of all the other into_literal_... functions.)
4195        if !self.is_constant() {
4196            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
4197                "Expected a constant expression, got {}",
4198                self
4199            )));
4200        }
4201        self.clone()
4202            .simplify_to_literal_with_result()
4203            .and_then(|row| {
4204                let datum = row.unpack_first();
4205                if datum.is_null() {
4206                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
4207                        "Expected an expression that evaluates to a non-null value, got {}",
4208                        self
4209                    )))
4210                } else {
4211                    Ok(datum.unwrap_int64())
4212                }
4213            })
4214    }
4215
4216    pub fn contains_parameters(&self) -> bool {
4217        let mut contains_parameters = false;
4218        #[allow(deprecated)]
4219        let _ = self.visit_recursively(0, &mut |_depth: usize,
4220                                                expr: &HirScalarExpr|
4221         -> Result<(), ()> {
4222            if let HirScalarExpr::Parameter(..) = expr {
4223                contains_parameters = true;
4224            }
4225            Ok(())
4226        });
4227        contains_parameters
4228    }
4229
4230    fn direct_subqueries(&self) -> Vec<&HirRelationExpr> {
4231        let mut subqueries: Vec<&HirRelationExpr> = vec![];
4232
4233        let mut worklist = vec![self];
4234        while let Some(elt) = worklist.pop() {
4235            match elt {
4236                HirScalarExpr::Column(_, _)
4237                | HirScalarExpr::Parameter(_, _)
4238                | HirScalarExpr::Literal(_, _, _)
4239                | HirScalarExpr::CallUnmaterializable(_, _) => (),
4240                HirScalarExpr::CallUnary {
4241                    func: _,
4242                    expr,
4243                    name: _,
4244                } => worklist.push(&*expr),
4245                HirScalarExpr::CallBinary {
4246                    func: _,
4247                    expr1,
4248                    expr2,
4249                    name: _name,
4250                } => {
4251                    // Push in reverse so children pop (and are visited) left-to-right.
4252                    worklist.push(&*expr2);
4253                    worklist.push(&*expr1);
4254                }
4255                HirScalarExpr::CallVariadic {
4256                    func: _,
4257                    exprs,
4258                    name: _name,
4259                } => {
4260                    worklist.extend(exprs.iter().rev());
4261                }
4262                HirScalarExpr::If {
4263                    cond,
4264                    then,
4265                    els,
4266                    name: _,
4267                } => {
4268                    worklist.push(&*els);
4269                    worklist.push(&*then);
4270                    worklist.push(&*cond);
4271                }
4272                HirScalarExpr::Exists(hir, _) | HirScalarExpr::Select(hir, _) => {
4273                    subqueries.push(&*hir);
4274                }
4275                HirScalarExpr::Windowing(
4276                    WindowExpr {
4277                        func,
4278                        partition_by,
4279                        order_by,
4280                    },
4281                    _,
4282                ) => {
4283                    // Push in reverse so children pop (and are visited) left-to-right:
4284                    // func args, then partition_by, then order_by.
4285                    worklist.extend(order_by.iter().rev());
4286                    worklist.extend(partition_by.iter().rev());
4287                    match func {
4288                        WindowExprType::Scalar(_) => (),
4289                        WindowExprType::Value(val) => worklist.push(&*val.args),
4290                        WindowExprType::Aggregate(agg) => worklist.push(&*agg.aggregate_expr.expr),
4291                    }
4292                }
4293            }
4294        }
4295
4296        subqueries
4297    }
4298
4299    fn direct_subqueries_mut(&mut self) -> Vec<&mut HirRelationExpr> {
4300        let mut subqueries: Vec<&mut HirRelationExpr> = vec![];
4301
4302        let mut worklist = vec![self];
4303        while let Some(elt) = worklist.pop() {
4304            match elt {
4305                HirScalarExpr::Column(_, _)
4306                | HirScalarExpr::Parameter(_, _)
4307                | HirScalarExpr::Literal(_, _, _)
4308                | HirScalarExpr::CallUnmaterializable(_, _) => (),
4309                HirScalarExpr::CallUnary {
4310                    func: _,
4311                    expr,
4312                    name: _,
4313                } => worklist.push(&mut **expr),
4314                HirScalarExpr::CallBinary {
4315                    func: _,
4316                    expr1,
4317                    expr2,
4318                    name: _name,
4319                } => {
4320                    // Push in reverse so children pop (and are visited) left-to-right.
4321                    worklist.push(&mut **expr2);
4322                    worklist.push(&mut **expr1);
4323                }
4324                HirScalarExpr::CallVariadic {
4325                    func: _,
4326                    exprs,
4327                    name: _name,
4328                } => {
4329                    worklist.extend(exprs.iter_mut().rev());
4330                }
4331                HirScalarExpr::If {
4332                    cond,
4333                    then,
4334                    els,
4335                    name: _,
4336                } => {
4337                    worklist.push(&mut **els);
4338                    worklist.push(&mut **then);
4339                    worklist.push(&mut **cond);
4340                }
4341                HirScalarExpr::Exists(hir, _) | HirScalarExpr::Select(hir, _) => {
4342                    subqueries.push(&mut **hir);
4343                }
4344                HirScalarExpr::Windowing(
4345                    WindowExpr {
4346                        func,
4347                        partition_by,
4348                        order_by,
4349                    },
4350                    _,
4351                ) => {
4352                    // Push in reverse so children pop (and are visited) left-to-right:
4353                    // func args, then partition_by, then order_by.
4354                    worklist.extend(order_by.iter_mut().rev());
4355                    worklist.extend(partition_by.iter_mut().rev());
4356                    match func {
4357                        WindowExprType::Scalar(_) => (),
4358                        WindowExprType::Value(val) => worklist.push(&mut val.args),
4359                        WindowExprType::Aggregate(agg) => {
4360                            worklist.push(&mut agg.aggregate_expr.expr)
4361                        }
4362                    }
4363                }
4364            }
4365        }
4366
4367        subqueries
4368    }
4369}
4370
4371/// Yields the direct scalar children of `self`. Stops at `Exists` / `Select`:
4372/// scalars inside subquery bodies are not surfaced. The asymmetry with
4373/// `VisitChildren<Self> for HirRelationExpr` (which does see through scalars
4374/// into subqueries) is what avoids the mutual recursion warned about on the
4375/// [`VisitChildren`] trait.
4376impl VisitChildren<Self> for HirScalarExpr {
4377    fn visit_children<F>(&self, mut f: F)
4378    where
4379        F: FnMut(&Self),
4380    {
4381        use HirScalarExpr::*;
4382        match self {
4383            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4384            CallUnary { expr, .. } => f(expr),
4385            CallBinary { expr1, expr2, .. } => {
4386                f(expr1);
4387                f(expr2);
4388            }
4389            CallVariadic { exprs, .. } => {
4390                for expr in exprs {
4391                    f(expr);
4392                }
4393            }
4394            If {
4395                cond,
4396                then,
4397                els,
4398                name: _,
4399            } => {
4400                f(cond);
4401                f(then);
4402                f(els);
4403            }
4404            Exists(..) | Select(..) => (),
4405            Windowing(expr, _name) => expr.visit_children(f),
4406        }
4407    }
4408
4409    fn visit_mut_children<F>(&mut self, mut f: F)
4410    where
4411        F: FnMut(&mut Self),
4412    {
4413        use HirScalarExpr::*;
4414        match self {
4415            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4416            CallUnary { expr, .. } => f(expr),
4417            CallBinary { expr1, expr2, .. } => {
4418                f(expr1);
4419                f(expr2);
4420            }
4421            CallVariadic { exprs, .. } => {
4422                for expr in exprs {
4423                    f(expr);
4424                }
4425            }
4426            If {
4427                cond,
4428                then,
4429                els,
4430                name: _,
4431            } => {
4432                f(cond);
4433                f(then);
4434                f(els);
4435            }
4436            Exists(..) | Select(..) => (),
4437            Windowing(expr, _name) => expr.visit_mut_children(f),
4438        }
4439    }
4440
4441    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
4442    where
4443        F: FnMut(&Self) -> Result<(), E>,
4444    {
4445        use HirScalarExpr::*;
4446        match self {
4447            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4448            CallUnary { expr, .. } => f(expr)?,
4449            CallBinary { expr1, expr2, .. } => {
4450                f(expr1)?;
4451                f(expr2)?;
4452            }
4453            CallVariadic { exprs, .. } => {
4454                for expr in exprs {
4455                    f(expr)?;
4456                }
4457            }
4458            If {
4459                cond,
4460                then,
4461                els,
4462                name: _,
4463            } => {
4464                f(cond)?;
4465                f(then)?;
4466                f(els)?;
4467            }
4468            Exists(..) | Select(..) => (),
4469            Windowing(expr, _name) => expr.try_visit_children(f)?,
4470        }
4471        Ok(())
4472    }
4473
4474    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
4475    where
4476        F: FnMut(&mut Self) -> Result<(), E>,
4477    {
4478        use HirScalarExpr::*;
4479        match self {
4480            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
4481            CallUnary { expr, .. } => f(expr)?,
4482            CallBinary { expr1, expr2, .. } => {
4483                f(expr1)?;
4484                f(expr2)?;
4485            }
4486            CallVariadic { exprs, .. } => {
4487                for expr in exprs {
4488                    f(expr)?;
4489                }
4490            }
4491            If {
4492                cond,
4493                then,
4494                els,
4495                name: _,
4496            } => {
4497                f(cond)?;
4498                f(then)?;
4499                f(els)?;
4500            }
4501            Exists(..) | Select(..) => (),
4502            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
4503        }
4504        Ok(())
4505    }
4506
4507    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a Self>
4508    where
4509        Self: 'a,
4510    {
4511        use HirScalarExpr::*;
4512        let v: Vec<&Self> = match self {
4513            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![],
4514            CallUnary { expr, .. } => vec![&*expr],
4515            CallBinary { expr1, expr2, .. } => {
4516                vec![&*expr1, &*expr2]
4517            }
4518            CallVariadic { exprs, .. } => exprs.iter().collect(),
4519            If {
4520                cond,
4521                then,
4522                els,
4523                name: _,
4524            } => {
4525                vec![&*cond, &*then, &*els]
4526            }
4527            Exists(..) | Select(..) => vec![],
4528            Windowing(expr, _name) => expr.children().collect(),
4529        };
4530        v.into_iter()
4531    }
4532
4533    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut Self>
4534    where
4535        Self: 'a,
4536    {
4537        use HirScalarExpr::*;
4538        let v: Vec<&mut Self> = match self {
4539            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![],
4540            CallUnary { expr, .. } => vec![&mut **expr],
4541            CallBinary { expr1, expr2, .. } => {
4542                vec![&mut **expr1, &mut **expr2]
4543            }
4544            CallVariadic { exprs, .. } => exprs.iter_mut().collect(),
4545            If {
4546                cond,
4547                then,
4548                els,
4549                name: _,
4550            } => {
4551                vec![&mut **cond, &mut **then, &mut **els]
4552            }
4553            Exists(..) | Select(..) => vec![],
4554            Windowing(expr, _name) => expr.children_mut().collect(),
4555        };
4556        v.into_iter()
4557    }
4558}
4559
4560/// Yields the immediate `HirRelationExpr` children of `self` (the bodies of
4561/// `Exists` / `Select`); does not descend into them.
4562impl VisitChildren<HirRelationExpr> for HirScalarExpr {
4563    fn visit_children<F>(&self, mut f: F)
4564    where
4565        F: FnMut(&HirRelationExpr),
4566    {
4567        use HirScalarExpr::*;
4568        match self {
4569            Column(..)
4570            | Parameter(..)
4571            | Literal(..)
4572            | CallUnmaterializable(..)
4573            | CallUnary { .. }
4574            | CallBinary { .. }
4575            | CallVariadic { .. }
4576            | If { .. }
4577            | Windowing(..) => (),
4578            Exists(expr, _name) | Select(expr, _name) => f(expr),
4579        }
4580    }
4581
4582    fn visit_mut_children<F>(&mut self, mut f: F)
4583    where
4584        F: FnMut(&mut HirRelationExpr),
4585    {
4586        use HirScalarExpr::*;
4587        match self {
4588            Column(..)
4589            | Parameter(..)
4590            | Literal(..)
4591            | CallUnmaterializable(..)
4592            | CallUnary { .. }
4593            | CallBinary { .. }
4594            | CallVariadic { .. }
4595            | If { .. }
4596            | Windowing(..) => (),
4597            Exists(expr, _name) | Select(expr, _name) => f(expr),
4598        }
4599    }
4600
4601    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
4602    where
4603        F: FnMut(&HirRelationExpr) -> Result<(), E>,
4604    {
4605        use HirScalarExpr::*;
4606        match self {
4607            Column(..)
4608            | Parameter(..)
4609            | Literal(..)
4610            | CallUnmaterializable(..)
4611            | CallUnary { .. }
4612            | CallBinary { .. }
4613            | CallVariadic { .. }
4614            | If { .. }
4615            | Windowing(..) => (),
4616            Exists(expr, _name) | Select(expr, _name) => f(expr)?,
4617        }
4618        Ok(())
4619    }
4620
4621    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
4622    where
4623        F: FnMut(&mut HirRelationExpr) -> Result<(), E>,
4624    {
4625        use HirScalarExpr::*;
4626        match self {
4627            Column(..)
4628            | Parameter(..)
4629            | Literal(..)
4630            | CallUnmaterializable(..)
4631            | CallUnary { .. }
4632            | CallBinary { .. }
4633            | CallVariadic { .. }
4634            | If { .. }
4635            | Windowing(..) => (),
4636            Exists(expr, _name) | Select(expr, _name) => f(expr)?,
4637        }
4638        Ok(())
4639    }
4640
4641    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a HirRelationExpr>
4642    where
4643        HirRelationExpr: 'a,
4644    {
4645        let mut child: Option<&HirRelationExpr> = None;
4646        use HirScalarExpr::*;
4647        match self {
4648            Column(..)
4649            | Parameter(..)
4650            | Literal(..)
4651            | CallUnmaterializable(..)
4652            | CallUnary { .. }
4653            | CallBinary { .. }
4654            | CallVariadic { .. }
4655            | If { .. }
4656            | Windowing(..) => (),
4657            Exists(expr, _name) | Select(expr, _name) => child = Some(&*expr),
4658        }
4659
4660        child.into_iter()
4661    }
4662
4663    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut HirRelationExpr>
4664    where
4665        HirRelationExpr: 'a,
4666    {
4667        let mut child: Option<&mut HirRelationExpr> = None;
4668        use HirScalarExpr::*;
4669        match self {
4670            Column(..)
4671            | Parameter(..)
4672            | Literal(..)
4673            | CallUnmaterializable(..)
4674            | CallUnary { .. }
4675            | CallBinary { .. }
4676            | CallVariadic { .. }
4677            | If { .. }
4678            | Windowing(..) => (),
4679            Exists(expr, _name) | Select(expr, _name) => child = Some(&mut **expr),
4680        }
4681
4682        child.into_iter()
4683    }
4684}
4685
4686impl AbstractExpr for HirScalarExpr {
4687    type Type = SqlColumnType;
4688
4689    fn typ(
4690        &self,
4691        outers: &[SqlRelationType],
4692        inner: &SqlRelationType,
4693        params: &BTreeMap<usize, SqlScalarType>,
4694    ) -> Self::Type {
4695        stack::maybe_grow(|| match self {
4696            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
4697                if *level == 0 {
4698                    inner.column_types[*column].clone()
4699                } else {
4700                    outers[*level - 1].column_types[*column].clone()
4701                }
4702            }
4703            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
4704            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
4705            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(),
4706            HirScalarExpr::CallUnary {
4707                expr,
4708                func,
4709                name: _,
4710            } => func.output_sql_type(expr.typ(outers, inner, params)),
4711            HirScalarExpr::CallBinary {
4712                expr1,
4713                expr2,
4714                func,
4715                name: _,
4716            } => func.output_sql_type(&[
4717                expr1.typ(outers, inner, params),
4718                expr2.typ(outers, inner, params),
4719            ]),
4720            HirScalarExpr::CallVariadic {
4721                exprs,
4722                func,
4723                name: _,
4724            } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
4725            HirScalarExpr::If {
4726                cond: _,
4727                then,
4728                els,
4729                name: _,
4730            } => {
4731                let then_type = then.typ(outers, inner, params);
4732                let else_type = els.typ(outers, inner, params);
4733                then_type.sql_union(&else_type).unwrap() // HIR deliberately not using `union`
4734            }
4735            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
4736            HirScalarExpr::Select(expr, _name) => {
4737                let mut outers = outers.to_vec();
4738                outers.insert(0, inner.clone());
4739                expr.typ(&outers, params)
4740                    .column_types
4741                    .into_element()
4742                    .nullable(true)
4743            }
4744            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
4745        })
4746    }
4747}
4748
4749impl AggregateExpr {
4750    pub fn typ(
4751        &self,
4752        outers: &[SqlRelationType],
4753        inner: &SqlRelationType,
4754        params: &BTreeMap<usize, SqlScalarType>,
4755    ) -> SqlColumnType {
4756        self.func
4757            .output_sql_type(self.expr.typ(outers, inner, params))
4758    }
4759
4760    /// Returns whether the expression is COUNT(*) or not.  Note that
4761    /// when we define the count builtin in sql::func, we convert
4762    /// COUNT(*) to COUNT(true), making it indistinguishable from
4763    /// literal COUNT(true), but we prefer to consider this as the
4764    /// former.
4765    ///
4766    /// (MIR has the same `is_count_asterisk`.)
4767    pub fn is_count_asterisk(&self) -> bool {
4768        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
4769    }
4770}