mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// The HIR intermediate representation as a single type, for use with the generic
/// `mz_expr::virtual_syntax` machinery: it pairs [`HirRelationExpr`] and [`HirScalarExpr`]
/// as one [`IR`].
// NOTE(review): presumably no `Debug` impl is needed for this field-less marker type.
#[allow(missing_debug_implementations)]
pub struct Hir;

impl IR for Hir {
    type Relation = HirRelationExpr;
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56        if *all {
57            let rhs = rhs.negate();
58            HirRelationExpr::union(lhs, rhs).threshold()
59        } else {
60            let lhs = lhs.distinct();
61            let rhs = rhs.distinct().negate();
62            HirRelationExpr::union(lhs, rhs).threshold()
63        }
64    }
65
66    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67        let mut result = None;
68
69        use HirRelationExpr::*;
70        if let Threshold { input } = expr {
71            if let Union { base: lhs, inputs } = input.as_ref() {
72                if let [rhs] = &inputs[..] {
73                    if let Negate { input: rhs } = rhs {
74                        match (lhs.as_ref(), rhs.as_ref()) {
75                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
76                                let all = false;
77                                let lhs = lhs.as_ref();
78                                let rhs = rhs.as_ref();
79                                result = Some(Except { all, lhs, rhs })
80                            }
81                            (lhs, rhs) => {
82                                let all = true;
83                                result = Some(Except { all, lhs, rhs })
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        result
92    }
93}
94
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
pub enum HirRelationExpr {
    /// A constant relation consisting of `rows`, with type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: RelationType,
    },
    /// A reference to the collection identified by `id` (for example a binding introduced by
    /// `Let` or `LetRec`), with type `typ`.
    Get {
        id: mz_expr::Id,
        typ: RelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        /// Each entry is (name, identifier used by `Get`, bound value, type of the value).
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, RelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        /// The name of the CTE.
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    /// Projects `input` onto the columns whose indices are listed in `outputs`.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Extends each row of `input` with one new column per expression in `scalars`.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Invokes the table function `func` with the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps the rows of `input` that satisfy all of `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of ScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of ScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicity of every row of `input`.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every relation in `inputs`; row multiplicities add up, which is
    /// what lets `EXCEPT` be expressed as `Threshold(Union(lhs, [Negate(rhs)]))`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
198
/// Stored column metadata.
///
/// An optional name carried along with a [`HirScalarExpr`]. It is wrapped in [`TreatAsEqual`],
/// presumably so that comparing, ordering, or hashing expressions ignores these names.
/// NOTE(review): confirm against `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
201
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant carries a [`NameMetadata`], an optional name for the expression.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via e.g. Exists. This means
    /// that a variable could refer to a column of the current input, or to a column of an outer
    /// relation. We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A reference to a query parameter (such as `$1`), identified by its number.
    Parameter(usize, NameMetadata),
    /// A literal value, stored in a [`Row`], together with its column type.
    Literal(Row, ColumnType, NameMetadata),
    /// A call to an [`UnmaterializableFunc`], which takes no arguments.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional that evaluates to `then` when `cond` holds and to `els` otherwise.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * more than 1 row, return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
243
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, including its arguments.
    pub func: WindowExprType,
    /// The `PARTITION BY` expressions; empty when there is no partitioning.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
262
263impl WindowExpr {
264    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265    where
266        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267    {
268        #[allow(deprecated)]
269        self.func.visit_expressions(f)?;
270        for expr in self.partition_by.iter() {
271            f(expr)?;
272        }
273        for expr in self.order_by.iter() {
274            f(expr)?;
275        }
276        Ok(())
277    }
278
279    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280    where
281        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282    {
283        #[allow(deprecated)]
284        self.func.visit_expressions_mut(f)?;
285        for expr in self.partition_by.iter_mut() {
286            f(expr)?;
287        }
288        for expr in self.order_by.iter_mut() {
289            f(expr)?;
290        }
291        Ok(())
292    }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296    fn visit_children<F>(&self, mut f: F)
297    where
298        F: FnMut(&HirScalarExpr),
299    {
300        self.func.visit_children(&mut f);
301        for expr in self.partition_by.iter() {
302            f(expr);
303        }
304        for expr in self.order_by.iter() {
305            f(expr);
306        }
307    }
308
309    fn visit_mut_children<F>(&mut self, mut f: F)
310    where
311        F: FnMut(&mut HirScalarExpr),
312    {
313        self.func.visit_mut_children(&mut f);
314        for expr in self.partition_by.iter_mut() {
315            f(expr);
316        }
317        for expr in self.order_by.iter_mut() {
318            f(expr);
319        }
320    }
321
322    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323    where
324        F: FnMut(&HirScalarExpr) -> Result<(), E>,
325        E: From<RecursionLimitError>,
326    {
327        self.func.try_visit_children(&mut f)?;
328        for expr in self.partition_by.iter() {
329            f(expr)?;
330        }
331        for expr in self.order_by.iter() {
332            f(expr)?;
333        }
334        Ok(())
335    }
336
337    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338    where
339        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340        E: From<RecursionLimitError>,
341    {
342        self.func.try_visit_mut_children(&mut f)?;
343        for expr in self.partition_by.iter_mut() {
344            f(expr)?;
345        }
346        for expr in self.order_by.iter_mut() {
347            f(expr)?;
348        }
349        Ok(())
350    }
351}
352
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function such as `row_number`; see [`ScalarWindowExpr`].
    Scalar(ScalarWindowExpr),
    /// A value window function such as `lag`; see [`ValueWindowExpr`].
    Value(ValueWindowExpr),
    /// An aggregate computed as a window function; see [`AggregateWindowExpr`].
    Aggregate(AggregateWindowExpr),
}
375
376impl WindowExprType {
377    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379    where
380        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381    {
382        #[allow(deprecated)]
383        match self {
384            Self::Scalar(expr) => expr.visit_expressions(f),
385            Self::Value(expr) => expr.visit_expressions(f),
386            Self::Aggregate(expr) => expr.visit_expressions(f),
387        }
388    }
389
390    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392    where
393        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394    {
395        #[allow(deprecated)]
396        match self {
397            Self::Scalar(expr) => expr.visit_expressions_mut(f),
398            Self::Value(expr) => expr.visit_expressions_mut(f),
399            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400        }
401    }
402
403    fn typ(
404        &self,
405        outers: &[RelationType],
406        inner: &RelationType,
407        params: &BTreeMap<usize, ScalarType>,
408    ) -> ColumnType {
409        match self {
410            Self::Scalar(expr) => expr.typ(outers, inner, params),
411            Self::Value(expr) => expr.typ(outers, inner, params),
412            Self::Aggregate(expr) => expr.typ(outers, inner, params),
413        }
414    }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418    fn visit_children<F>(&self, f: F)
419    where
420        F: FnMut(&HirScalarExpr),
421    {
422        match self {
423            Self::Scalar(_) => (),
424            Self::Value(expr) => expr.visit_children(f),
425            Self::Aggregate(expr) => expr.visit_children(f),
426        }
427    }
428
429    fn visit_mut_children<F>(&mut self, f: F)
430    where
431        F: FnMut(&mut HirScalarExpr),
432    {
433        match self {
434            Self::Scalar(_) => (),
435            Self::Value(expr) => expr.visit_mut_children(f),
436            Self::Aggregate(expr) => expr.visit_mut_children(f),
437        }
438    }
439
440    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441    where
442        F: FnMut(&HirScalarExpr) -> Result<(), E>,
443        E: From<RecursionLimitError>,
444    {
445        match self {
446            Self::Scalar(_) => Ok(()),
447            Self::Value(expr) => expr.try_visit_children(f),
448            Self::Aggregate(expr) => expr.try_visit_children(f),
449        }
450    }
451
452    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453    where
454        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455        E: From<RecursionLimitError>,
456    {
457        match self {
458            Self::Scalar(_) => Ok(()),
459            Self::Value(expr) => expr.try_visit_mut_children(f),
460            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461        }
462    }
463}
464
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A call to a scalar window function (see [`ScalarWindowFunc`]).
pub struct ScalarWindowExpr {
    /// The window function being invoked.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
470
471impl ScalarWindowExpr {
472    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474    where
475        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476    {
477        match self.func {
478            ScalarWindowFunc::RowNumber => {}
479            ScalarWindowFunc::Rank => {}
480            ScalarWindowFunc::DenseRank => {}
481        }
482        Ok(())
483    }
484
485    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487    where
488        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489    {
490        match self.func {
491            ScalarWindowFunc::RowNumber => {}
492            ScalarWindowFunc::Rank => {}
493            ScalarWindowFunc::DenseRank => {}
494        }
495        Ok(())
496    }
497
498    fn typ(
499        &self,
500        _outers: &[RelationType],
501        _inner: &RelationType,
502        _params: &BTreeMap<usize, ScalarType>,
503    ) -> ColumnType {
504        self.func.output_type()
505    }
506
507    pub fn into_expr(self) -> mz_expr::AggregateFunc {
508        match self.func {
509            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510                order_by: self.order_by,
511            },
512            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513                order_by: self.order_by,
514            },
515            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516                order_by: self.order_by,
517            },
518        }
519    }
520}
521
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Scalar Window functions
///
/// All of them produce a non-nullable `Int64`.
pub enum ScalarWindowFunc {
    /// `row_number`
    RowNumber,
    /// `rank`
    Rank,
    /// `dense_rank`
    DenseRank,
}
529
530impl Display for ScalarWindowFunc {
531    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532        match self {
533            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534            ScalarWindowFunc::Rank => write!(f, "rank"),
535            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536        }
537    }
538}
539
540impl ScalarWindowFunc {
541    pub fn output_type(&self) -> ColumnType {
542        match self {
543            ScalarWindowFunc::RowNumber => ScalarType::Int64.nullable(false),
544            ScalarWindowFunc::Rank => ScalarType::Int64.nullable(false),
545            ScalarWindowFunc::DenseRank => ScalarType::Int64.nullable(false),
546        }
547    }
548}
549
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// A call to a value window function (see [`ValueWindowFunc`]).
pub struct ValueWindowExpr {
    /// The value window function (or fused group of functions) being invoked.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame. When lowered by `ValueWindowFunc::into_expr`, it is used by
    /// `first_value`/`last_value` but not by `lag`/`lead`.
    pub window_frame: WindowFrame,
    /// Whether NULLs are ignored. When lowered by `ValueWindowFunc::into_expr`, it is used by
    /// `lag`/`lead` but not by `first_value`/`last_value`.
    pub ignore_nulls: bool,
}
564
565impl Display for ValueWindowFunc {
566    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567        match self {
568            ValueWindowFunc::Lag => write!(f, "lag"),
569            ValueWindowFunc::Lead => write!(f, "lead"),
570            ValueWindowFunc::FirstValue => write!(f, "first_value"),
571            ValueWindowFunc::LastValue => write!(f, "last_value"),
572            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573        }
574    }
575}
576
577impl ValueWindowExpr {
578    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580    where
581        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582    {
583        f(&self.args)
584    }
585
586    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588    where
589        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590    {
591        f(&mut self.args)
592    }
593
594    fn typ(
595        &self,
596        outers: &[RelationType],
597        inner: &RelationType,
598        params: &BTreeMap<usize, ScalarType>,
599    ) -> ColumnType {
600        self.func.output_type(self.args.typ(outers, inner, params))
601    }
602
603    /// Converts into `mz_expr::AggregateFunc`.
604    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605        (
606            self.args,
607            self.func
608                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609        )
610    }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614    fn visit_children<F>(&self, mut f: F)
615    where
616        F: FnMut(&HirScalarExpr),
617    {
618        f(&self.args)
619    }
620
621    fn visit_mut_children<F>(&mut self, mut f: F)
622    where
623        F: FnMut(&mut HirScalarExpr),
624    {
625        f(&mut self.args)
626    }
627
628    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629    where
630        F: FnMut(&HirScalarExpr) -> Result<(), E>,
631        E: From<RecursionLimitError>,
632    {
633        f(&self.args)
634    }
635
636    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637    where
638        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639        E: From<RecursionLimitError>,
640    {
641        f(&mut self.args)
642    }
643}
644
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// `lag`: its argument is a (value, offset, default) record.
    Lag,
    /// `lead`: its argument is a (value, offset, default) record.
    Lead,
    /// `first_value`
    FirstValue,
    /// `last_value`
    LastValue,
    /// Several value window functions fused into a single call. The arguments of the
    /// constituent calls are wrapped in an outer record, and the result is a record with one
    /// field per constituent function.
    Fused(Vec<ValueWindowFunc>),
}
654
655impl ValueWindowFunc {
656    pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
657        match self {
658            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659                // The input is a (value, offset, default) record, so extract the type of the first arg
660                input_type.scalar_type.unwrap_record_element_type()[0]
661                    .clone()
662                    .nullable(true)
663            }
664            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665                input_type.scalar_type.nullable(true)
666            }
667            ValueWindowFunc::Fused(funcs) => {
668                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669                ScalarType::Record {
670                    fields: funcs
671                        .iter()
672                        .zip_eq(input_types)
673                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674                        .collect(),
675                    custom_id: None,
676                }
677                .nullable(false)
678            }
679        }
680    }
681
682    pub fn into_expr(
683        self,
684        order_by: Vec<ColumnOrder>,
685        window_frame: WindowFrame,
686        ignore_nulls: bool,
687    ) -> mz_expr::AggregateFunc {
688        match self {
689            // Lag and Lead are fundamentally the same function, just with opposite directions
690            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691                order_by,
692                lag_lead: mz_expr::LagLeadType::Lag,
693                ignore_nulls,
694            },
695            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696                order_by,
697                lag_lead: mz_expr::LagLeadType::Lead,
698                ignore_nulls,
699            },
700            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701                order_by,
702                window_frame,
703            },
704            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705                order_by,
706                window_frame,
707            },
708            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709                funcs: funcs
710                    .into_iter()
711                    .map(|func| {
712                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713                    })
714                    .collect(),
715                order_by,
716            },
717        }
718    }
719}
720
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
/// An aggregation (e.g. `sum(x)`) evaluated as a window function.
pub struct AggregateWindowExpr {
    /// The aggregate function together with its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame over which the aggregate is computed.
    pub window_frame: WindowFrame,
}
727
728impl AggregateWindowExpr {
729    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731    where
732        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733    {
734        f(&self.aggregate_expr.expr)
735    }
736
737    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739    where
740        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741    {
742        f(&mut self.aggregate_expr.expr)
743    }
744
745    fn typ(
746        &self,
747        outers: &[RelationType],
748        inner: &RelationType,
749        params: &BTreeMap<usize, ScalarType>,
750    ) -> ColumnType {
751        self.aggregate_expr
752            .func
753            .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754    }
755
756    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758            (
759                self.aggregate_expr.expr,
760                FusedWindowAggregate {
761                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762                    order_by: self.order_by,
763                    window_frame: self.window_frame,
764                },
765            )
766        } else {
767            (
768                self.aggregate_expr.expr,
769                WindowAggregate {
770                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771                    order_by: self.order_by,
772                    window_frame: self.window_frame,
773                },
774            )
775        }
776    }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780    fn visit_children<F>(&self, mut f: F)
781    where
782        F: FnMut(&HirScalarExpr),
783    {
784        f(&self.aggregate_expr.expr)
785    }
786
787    fn visit_mut_children<F>(&mut self, mut f: F)
788    where
789        F: FnMut(&mut HirScalarExpr),
790    {
791        f(&mut self.aggregate_expr.expr)
792    }
793
794    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795    where
796        F: FnMut(&HirScalarExpr) -> Result<(), E>,
797        E: From<RecursionLimitError>,
798    {
799        f(&self.aggregate_expr.expr)
800    }
801
802    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803    where
804        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805        E: From<RecursionLimitError>,
806    {
807        f(&mut self.aggregate_expr.expr)
808    }
809}
810
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression whose type is already determined.
    Coerced(HirScalarExpr),
    /// A parameter (such as `$1`) whose type has not yet been determined.
    Parameter(usize),
    /// The literal `NULL`.
    LiteralNull,
    /// A string literal, which may be coerced to another type (e.g. '42' to an integer).
    LiteralString(String),
    /// A record literal whose fields may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
843
844impl CoercibleScalarExpr {
845    pub fn type_as(self, ecx: &ExprContext, ty: &ScalarType) -> Result<HirScalarExpr, PlanError> {
846        let expr = typeconv::plan_coerce(ecx, self, ty)?;
847        let expr_ty = ecx.scalar_type(&expr);
848        if ty != &expr_ty {
849            sql_bail!(
850                "{} must have type {}, not type {}",
851                ecx.name,
852                ecx.humanize_scalar_type(ty, false),
853                ecx.humanize_scalar_type(&expr_ty, false),
854            );
855        }
856        Ok(expr)
857    }
858
859    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
860        typeconv::plan_coerce(ecx, self, &ScalarType::String)
861    }
862
863    pub fn cast_to(
864        self,
865        ecx: &ExprContext,
866        ccx: CastContext,
867        ty: &ScalarType,
868    ) -> Result<HirScalarExpr, PlanError> {
869        let expr = typeconv::plan_coerce(ecx, self, ty)?;
870        typeconv::plan_cast(ecx, ccx, expr, ty)
871    }
872}
873
/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The column type is fully determined.
    Coerced(ColumnType),
    /// The type of a record literal, with one (possibly uncoerced) type per field.
    Record(Vec<CoercibleColumnType>),
    /// The type has not been determined yet.
    Uncoerced,
}
881
882impl CoercibleColumnType {
883    /// Reports the nullability of the type.
884    pub fn nullable(&self) -> bool {
885        match self {
886            // A coerced value's nullability is known.
887            CoercibleColumnType::Coerced(ct) => ct.nullable,
888
889            // A literal record can never be null.
890            CoercibleColumnType::Record(_) => false,
891
892            // An uncoerced literal may be the literal `NULL`, so we have
893            // to conservatively assume it is nullable.
894            CoercibleColumnType::Uncoerced => true,
895        }
896    }
897}
898
/// The scalar type for a [`CoercibleScalarExpr`].
///
/// This mirrors [`CoercibleColumnType`], minus the outer nullability of a
/// coerced column type.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A known scalar type.
    Coerced(ScalarType),
    /// A literal record; each field keeps its (possibly uncoerced) column type.
    Record(Vec<CoercibleColumnType>),
    /// A type that has not been coerced yet.
    Uncoerced,
}
906
907impl CoercibleScalarType {
908    /// Reports whether the scalar type has been coerced.
909    pub fn is_coerced(&self) -> bool {
910        matches!(self, CoercibleScalarType::Coerced(_))
911    }
912
913    /// Returns the coerced scalar type, if the type is coerced.
914    pub fn as_coerced(&self) -> Option<&ScalarType> {
915        match self {
916            CoercibleScalarType::Coerced(t) => Some(t),
917            _ => None,
918        }
919    }
920
921    /// If the type is coerced, apply the mapping function to the contained
922    /// scalar type.
923    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
924    where
925        F: FnOnce(ScalarType) -> ScalarType,
926    {
927        match self {
928            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
929            _ => self,
930        }
931    }
932
933    /// If the type is an coercible record, forcibly converts to a coerced
934    /// record type. Any uncoerced field types are assumed to be of type text.
935    ///
936    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
937    /// accepts a type hint that can indicate the types of uncoerced field
938    /// types.
939    pub fn force_coerced_if_record(&mut self) {
940        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> ScalarType {
941            let mut fields = vec![];
942            for (i, uf) in uncoerced_fields.enumerate() {
943                let name = ColumnName::from(format!("f{}", i + 1));
944                let ty = match uf {
945                    CoercibleColumnType::Coerced(ty) => ty,
946                    CoercibleColumnType::Record(mut fields) => {
947                        convert(fields.drain(..)).nullable(false)
948                    }
949                    CoercibleColumnType::Uncoerced => ScalarType::String.nullable(true),
950                };
951                fields.push((name, ty))
952            }
953            ScalarType::Record {
954                fields: fields.into(),
955                custom_id: None,
956            }
957        }
958
959        if let CoercibleScalarType::Record(fields) = self {
960            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
961        }
962    }
963}
964
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` holds the relation types of the enclosing scopes, with index 0
    /// being the nearest one. `inner` is the relation type of the current
    /// scope. `params` holds known query parameter types, keyed by parameter
    /// number (presumably matching `CoercibleScalarExpr::Parameter` — confirm).
    fn typ(
        &self,
        outers: &[RelationType],
        inner: &RelationType,
        params: &BTreeMap<usize, ScalarType>,
    ) -> Self::Type;
}
979
980impl AbstractExpr for CoercibleScalarExpr {
981    type Type = CoercibleColumnType;
982
983    fn typ(
984        &self,
985        outers: &[RelationType],
986        inner: &RelationType,
987        params: &BTreeMap<usize, ScalarType>,
988    ) -> Self::Type {
989        match self {
990            CoercibleScalarExpr::Coerced(expr) => {
991                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
992            }
993            CoercibleScalarExpr::LiteralRecord(scalars) => {
994                let fields = scalars
995                    .iter()
996                    .map(|s| s.typ(outers, inner, params))
997                    .collect();
998                CoercibleColumnType::Record(fields)
999            }
1000            _ => CoercibleColumnType::Uncoerced,
1001        }
1002    }
1003}
1004
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `ColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object obtained by [`AbstractColumnType::scalar_type`].
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object, consuming `self` and discarding column-level information such
    /// as nullability.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1016
1017impl AbstractColumnType for ColumnType {
1018    type AbstractScalarType = ScalarType;
1019
1020    fn scalar_type(self) -> Self::AbstractScalarType {
1021        self.scalar_type
1022    }
1023}
1024
1025impl AbstractColumnType for CoercibleColumnType {
1026    type AbstractScalarType = CoercibleScalarType;
1027
1028    fn scalar_type(self) -> Self::AbstractScalarType {
1029        match self {
1030            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1031            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1032            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1033        }
1034    }
1035}
1036
1037impl From<HirScalarExpr> for CoercibleScalarExpr {
1038    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1039        CoercibleScalarExpr::Coerced(expr)
1040    }
1041}
1042
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier.
    pub column: usize,
}
1064
/// The kind of a join between a left and a right input.
///
/// For the outer kinds, the columns of the non-preserved side become
/// nullable in the join's output type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Inner join: no columns become nullable.
    Inner,
    /// Left outer join: right-side columns become nullable.
    LeftOuter,
    /// Right outer join: left-side columns become nullable.
    RightOuter,
    /// Full outer join: columns of both sides become nullable.
    FullOuter,
}
1072
1073impl fmt::Display for JoinKind {
1074    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1075        write!(
1076            f,
1077            "{}",
1078            match self {
1079                JoinKind::Inner => "Inner",
1080                JoinKind::LeftOuter => "LeftOuter",
1081                JoinKind::RightOuter => "RightOuter",
1082                JoinKind::FullOuter => "FullOuter",
1083            }
1084        )
1085    }
1086}
1087
1088impl JoinKind {
1089    pub fn can_be_correlated(&self) -> bool {
1090        match self {
1091            JoinKind::Inner | JoinKind::LeftOuter => true,
1092            JoinKind::RightOuter | JoinKind::FullOuter => false,
1093        }
1094    }
1095}
1096
/// An aggregate function applied to an input scalar expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression whose values are aggregated.
    pub expr: Box<HirScalarExpr>,
    /// Whether this is a `DISTINCT` aggregate.
    ///
    /// NOTE(review): presumably duplicate input values are removed before
    /// aggregating, as in SQL `agg(DISTINCT ...)` — confirm.
    pub distinct: bool,
}
1103
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, aggregate output columns are nullable more often here than
/// in `mz_expr`. These aggregates may be applied over empty result sets and
/// should be null in those cases, whereas `mz_expr` variants only return null
/// values when supplied nulls as input.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    // Maximum of the inputs; one variant per supported input type.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // Minimum of the inputs; one variant per supported input type.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // Sum of the inputs; one variant per supported input type. Small integer
    // sums widen to a larger output type (see `AggregateFunc::output_type`).
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// Produces an `Int64`; the only aggregate whose output is non-nullable.
    Count,
    /// Produces a `Bool`; its identity input is `false`.
    Any,
    /// Produces a `Bool`; its identity input is `true`.
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: ScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Produces a `String`.
    ///
    /// NOTE(review): presumably the inputs are laid out like the other
    /// `order_by`-carrying aggregates (value first, ordering columns after) —
    /// confirm.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, whose each
    /// component will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1207
impl AggregateFunc {
    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
    ///
    /// Every variant except `FusedWindowAgg` maps onto the identically named
    /// `mz_expr` variant. Any `order_by`/`value_type` payload is carried over
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if called on a `FusedWindowAgg`, which must instead be handled
    /// by `AggregateWindowExpr::into_expr`.
    pub fn into_expr(self) -> mz_expr::AggregateFunc {
        match self {
            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
            AggregateFunc::All => mz_expr::AggregateFunc::All,
            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
            AggregateFunc::JsonbObjectAgg { order_by } => {
                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
            }
            AggregateFunc::MapAgg {
                order_by,
                value_type,
            } => mz_expr::AggregateFunc::MapAgg {
                order_by,
                value_type,
            },
            AggregateFunc::ArrayConcat { order_by } => {
                mz_expr::AggregateFunc::ArrayConcat { order_by }
            }
            AggregateFunc::ListConcat { order_by } => {
                mz_expr::AggregateFunc::ListConcat { order_by }
            }
            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
            // `AggregateFunc::FusedWindowAgg` should be specially handled in
            // `AggregateWindowExpr::into_expr`.
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                panic!("into_expr called on FusedWindowAgg")
            }
            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
        }
    }

    /// Returns a datum whose inclusion in the aggregation will not change its
    /// result.
    ///
    /// `Any` uses `false`, `All` uses `true`, `Dummy` uses `Datum::Dummy`,
    /// and the concatenations use an empty array/list. All other aggregates
    /// use `Datum::Null`.
    ///
    /// # Panics
    ///
    /// Panics if called on a `FusedWindowAgg`.
    pub fn identity_datum(&self) -> Datum<'static> {
        match self {
            AggregateFunc::Any => Datum::False,
            AggregateFunc::All => Datum::True,
            AggregateFunc::Dummy => Datum::Dummy,
            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
            // NOTE(review): `Null` being neutral presumes each of these aggregates
            // ignores null inputs (documented for `JsonbAgg`) — confirm per variant.
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumInt16
            | AggregateFunc::SumInt32
            | AggregateFunc::SumInt64
            | AggregateFunc::SumUInt16
            | AggregateFunc::SumUInt32
            | AggregateFunc::SumUInt64
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Count
            | AggregateFunc::JsonbAgg { .. }
            | AggregateFunc::JsonbObjectAgg { .. }
            | AggregateFunc::MapAgg { .. }
            | AggregateFunc::StringAgg { .. } => Datum::Null,
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
                // in HIR planning, because it is introduced only during HIR transformation.
                //
                // The implementation could be something like the following, except that we need to
                // return a `Datum<'static>`, so we can't actually dynamically compute this.
                // ```
                // let temp_storage = RowArena::new();
                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
                // ```
                panic!("FusedWindowAgg doesn't have an identity_datum")
            }
        }
    }

    /// The output column type for the result of an aggregation.
    ///
    /// The output column type also contains nullability information, which
    /// is (without further information) true for aggregations that are not
    /// counts.
    ///
    /// Integer sums widen: 16/32-bit sums produce 64-bit integers, and 64-bit
    /// sums produce `Numeric` with a max scale of zero. For `FusedWindowAgg`,
    /// the input must be a record whose fields line up one-to-one with
    /// `funcs`. The output is a record of each function's output type with
    /// empty field names.
    ///
    /// # Panics
    ///
    /// Panics if `ArrayConcat`/`ListConcat` receive a non-record input. Also
    /// panics (via `zip_eq`) if a `FusedWindowAgg` input record's arity differs
    /// from the number of fused functions.
    pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
        let scalar_type = match self {
            AggregateFunc::Count => ScalarType::Int64,
            AggregateFunc::Any => ScalarType::Bool,
            AggregateFunc::All => ScalarType::Bool,
            AggregateFunc::JsonbAgg { .. } => ScalarType::Jsonb,
            AggregateFunc::JsonbObjectAgg { .. } => ScalarType::Jsonb,
            AggregateFunc::StringAgg { .. } => ScalarType::String,
            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => ScalarType::Int64,
            AggregateFunc::SumInt64 => ScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => ScalarType::UInt64,
            AggregateFunc::SumUInt64 => ScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::MapAgg { value_type, .. } => ScalarType::Map {
                value_type: Box::new(value_type.clone()),
                custom_id: None,
            },
            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
                match input_type.scalar_type {
                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
                    ScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
                    // NOTE(review): this assumes the input is always a record here.
                    _ => unreachable!(),
                }
            }
            // These aggregates produce a value of their input's type.
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Dummy => input_type.scalar_type,
            AggregateFunc::FusedWindowAgg { funcs } => {
                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
                ScalarType::Record {
                    fields: funcs
                        .iter()
                        .zip_eq(input_types)
                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
                        .collect(),
                    custom_id: None,
                }
            }
        };
        // max/min/sum return null on empty sets
        // (every aggregate except `Count` is treated as nullable for this reason).
        let nullable = !matches!(self, AggregateFunc::Count);
        scalar_type.nullable(nullable)
    }

    /// Reports whether the result of this aggregate depends on the order of
    /// its inputs. These are exactly the variants that carry an `order_by`.
    pub fn is_order_sensitive(&self) -> bool {
        use AggregateFunc::*;
        matches!(
            self,
            JsonbAgg { .. }
                | JsonbObjectAgg { .. }
                | MapAgg { .. }
                | ArrayConcat { .. }
                | ListConcat { .. }
                | StringAgg { .. }
        )
    }
}
1461
1462impl HirRelationExpr {
1463    pub fn typ(
1464        &self,
1465        outers: &[RelationType],
1466        params: &BTreeMap<usize, ScalarType>,
1467    ) -> RelationType {
1468        stack::maybe_grow(|| match self {
1469            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1470            HirRelationExpr::Get { typ, .. } => typ.clone(),
1471            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1472            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1473            HirRelationExpr::Project { input, outputs } => {
1474                let input_typ = input.typ(outers, params);
1475                RelationType::new(
1476                    outputs
1477                        .iter()
1478                        .map(|&i| input_typ.column_types[i].clone())
1479                        .collect(),
1480                )
1481            }
1482            HirRelationExpr::Map { input, scalars } => {
1483                let mut typ = input.typ(outers, params);
1484                for scalar in scalars {
1485                    typ.column_types.push(scalar.typ(outers, &typ, params));
1486                }
1487                typ
1488            }
1489            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1490            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1491                input.typ(outers, params)
1492            }
1493            HirRelationExpr::Join {
1494                left, right, kind, ..
1495            } => {
1496                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1497                let right_nullable =
1498                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1499                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1500                    let nullable = t.nullable || left_nullable;
1501                    t.nullable(nullable)
1502                });
1503                let mut outers = outers.to_vec();
1504                outers.insert(0, RelationType::new(lt.clone().collect()));
1505                let rt = right
1506                    .typ(&outers, params)
1507                    .column_types
1508                    .into_iter()
1509                    .map(|t| {
1510                        let nullable = t.nullable || right_nullable;
1511                        t.nullable(nullable)
1512                    });
1513                RelationType::new(lt.chain(rt).collect())
1514            }
1515            HirRelationExpr::Reduce {
1516                input,
1517                group_key,
1518                aggregates,
1519                expected_group_size: _,
1520            } => {
1521                let input_typ = input.typ(outers, params);
1522                let mut column_types = group_key
1523                    .iter()
1524                    .map(|&i| input_typ.column_types[i].clone())
1525                    .collect::<Vec<_>>();
1526                for agg in aggregates {
1527                    column_types.push(agg.typ(outers, &input_typ, params));
1528                }
1529                // TODO(frank): add primary key information.
1530                RelationType::new(column_types)
1531            }
1532            // TODO(frank): check for removal; add primary key information.
1533            HirRelationExpr::Distinct { input }
1534            | HirRelationExpr::Negate { input }
1535            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1536            HirRelationExpr::Union { base, inputs } => {
1537                let mut base_cols = base.typ(outers, params).column_types;
1538                for input in inputs {
1539                    for (base_col, col) in base_cols
1540                        .iter_mut()
1541                        .zip_eq(input.typ(outers, params).column_types)
1542                    {
1543                        *base_col = base_col.union(&col).unwrap();
1544                    }
1545                }
1546                RelationType::new(base_cols)
1547            }
1548        })
1549    }
1550
1551    pub fn arity(&self) -> usize {
1552        match self {
1553            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1554            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1555            HirRelationExpr::Let { body, .. } => body.arity(),
1556            HirRelationExpr::LetRec { body, .. } => body.arity(),
1557            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1558            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1559            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1560            HirRelationExpr::Filter { input, .. }
1561            | HirRelationExpr::TopK { input, .. }
1562            | HirRelationExpr::Distinct { input }
1563            | HirRelationExpr::Negate { input }
1564            | HirRelationExpr::Threshold { input } => input.arity(),
1565            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1566            HirRelationExpr::Union { base, .. } => base.arity(),
1567            HirRelationExpr::Reduce {
1568                group_key,
1569                aggregates,
1570                ..
1571            } => group_key.len() + aggregates.len(),
1572        }
1573    }
1574
1575    /// If self is a constant, return the value and the type, otherwise `None`.
1576    pub fn as_const(&self) -> Option<(&Vec<Row>, &RelationType)> {
1577        match self {
1578            Self::Constant { rows, typ } => Some((rows, typ)),
1579            _ => None,
1580        }
1581    }
1582
1583    /// Reports whether this expression contains a column reference to its
1584    /// direct parent scope.
1585    pub fn is_correlated(&self) -> bool {
1586        let mut correlated = false;
1587        #[allow(deprecated)]
1588        self.visit_columns(0, &mut |depth, col| {
1589            if col.level > depth && col.level - depth == 1 {
1590                correlated = true;
1591            }
1592        });
1593        correlated
1594    }
1595
1596    pub fn is_join_identity(&self) -> bool {
1597        match self {
1598            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1599            _ => false,
1600        }
1601    }
1602
1603    pub fn project(self, outputs: Vec<usize>) -> Self {
1604        if outputs.iter().copied().eq(0..self.arity()) {
1605            // The projection is trivial. Suppress it.
1606            self
1607        } else {
1608            HirRelationExpr::Project {
1609                input: Box::new(self),
1610                outputs,
1611            }
1612        }
1613    }
1614
1615    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1616        if scalars.is_empty() {
1617            // The map is trivial. Suppress it.
1618            self
1619        } else if let HirRelationExpr::Map {
1620            scalars: old_scalars,
1621            input: _,
1622        } = &mut self
1623        {
1624            // Map applied to a map. Fuse the maps.
1625            old_scalars.extend(scalars);
1626            self
1627        } else {
1628            HirRelationExpr::Map {
1629                input: Box::new(self),
1630                scalars,
1631            }
1632        }
1633    }
1634
1635    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1636        if let HirRelationExpr::Filter {
1637            input: _,
1638            predicates,
1639        } = &mut self
1640        {
1641            predicates.extend(preds);
1642            predicates.sort();
1643            predicates.dedup();
1644            self
1645        } else {
1646            preds.sort();
1647            preds.dedup();
1648            HirRelationExpr::Filter {
1649                input: Box::new(self),
1650                predicates: preds,
1651            }
1652        }
1653    }
1654
1655    pub fn reduce(
1656        self,
1657        group_key: Vec<usize>,
1658        aggregates: Vec<AggregateExpr>,
1659        expected_group_size: Option<u64>,
1660    ) -> Self {
1661        HirRelationExpr::Reduce {
1662            input: Box::new(self),
1663            group_key,
1664            aggregates,
1665            expected_group_size,
1666        }
1667    }
1668
1669    pub fn top_k(
1670        self,
1671        group_key: Vec<usize>,
1672        order_key: Vec<ColumnOrder>,
1673        limit: Option<HirScalarExpr>,
1674        offset: HirScalarExpr,
1675        expected_group_size: Option<u64>,
1676    ) -> Self {
1677        HirRelationExpr::TopK {
1678            input: Box::new(self),
1679            group_key,
1680            order_key,
1681            limit,
1682            offset,
1683            expected_group_size,
1684        }
1685    }
1686
1687    pub fn negate(self) -> Self {
1688        if let HirRelationExpr::Negate { input } = self {
1689            *input
1690        } else {
1691            HirRelationExpr::Negate {
1692                input: Box::new(self),
1693            }
1694        }
1695    }
1696
1697    pub fn distinct(self) -> Self {
1698        if let HirRelationExpr::Distinct { .. } = self {
1699            self
1700        } else {
1701            HirRelationExpr::Distinct {
1702                input: Box::new(self),
1703            }
1704        }
1705    }
1706
1707    pub fn threshold(self) -> Self {
1708        if let HirRelationExpr::Threshold { .. } = self {
1709            self
1710        } else {
1711            HirRelationExpr::Threshold {
1712                input: Box::new(self),
1713            }
1714        }
1715    }
1716
1717    pub fn union(self, other: Self) -> Self {
1718        let mut terms = Vec::new();
1719        if let HirRelationExpr::Union { base, inputs } = self {
1720            terms.push(*base);
1721            terms.extend(inputs);
1722        } else {
1723            terms.push(self);
1724        }
1725        if let HirRelationExpr::Union { base, inputs } = other {
1726            terms.push(*base);
1727            terms.extend(inputs);
1728        } else {
1729            terms.push(other);
1730        }
1731        HirRelationExpr::Union {
1732            base: Box::new(terms.remove(0)),
1733            inputs: terms,
1734        }
1735    }
1736
1737    pub fn exists(self) -> HirScalarExpr {
1738        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1739    }
1740
1741    pub fn select(self) -> HirScalarExpr {
1742        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1743    }
1744
1745    pub fn join(
1746        self,
1747        mut right: HirRelationExpr,
1748        on: HirScalarExpr,
1749        kind: JoinKind,
1750    ) -> HirRelationExpr {
1751        if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1752        {
1753            // The join can be elided, but we need to adjust column references
1754            // on the right-hand side to account for the removal of the scope
1755            // introduced by the join.
1756            #[allow(deprecated)]
1757            right.visit_columns_mut(0, &mut |depth, col| {
1758                if col.level > depth {
1759                    col.level -= 1;
1760                }
1761            });
1762            right
1763        } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1764            self
1765        } else {
1766            HirRelationExpr::Join {
1767                left: Box::new(self),
1768                right: Box::new(right),
1769                on,
1770                kind,
1771            }
1772        }
1773    }
1774
1775    pub fn take(&mut self) -> HirRelationExpr {
1776        mem::replace(
1777            self,
1778            HirRelationExpr::constant(vec![], RelationType::new(Vec::new())),
1779        )
1780    }
1781
1782    #[deprecated = "Use `Visit::visit_post`."]
1783    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1784    where
1785        F: FnMut(&'a Self, usize),
1786    {
1787        #[allow(deprecated)]
1788        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1789                                                 depth: usize|
1790         -> Result<(), ()> {
1791            f(e, depth);
1792            Ok(())
1793        });
1794    }
1795
1796    #[deprecated = "Use `Visit::try_visit_post`."]
1797    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1798    where
1799        F: FnMut(&'a Self, usize) -> Result<(), E>,
1800    {
1801        #[allow(deprecated)]
1802        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1803            e.visit_fallible(depth, f)
1804        })?;
1805        f(self, depth)
1806    }
1807
1808    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1809    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1810    where
1811        F: FnMut(&'a Self, usize) -> Result<(), E>,
1812    {
1813        match self {
1814            HirRelationExpr::Constant { .. }
1815            | HirRelationExpr::Get { .. }
1816            | HirRelationExpr::CallTable { .. } => (),
1817            HirRelationExpr::Let { body, value, .. } => {
1818                f(value, depth)?;
1819                f(body, depth)?;
1820            }
1821            HirRelationExpr::LetRec {
1822                limit: _,
1823                bindings,
1824                body,
1825            } => {
1826                for (_, _, value, _) in bindings.iter() {
1827                    f(value, depth)?;
1828                }
1829                f(body, depth)?;
1830            }
1831            HirRelationExpr::Project { input, .. } => {
1832                f(input, depth)?;
1833            }
1834            HirRelationExpr::Map { input, .. } => {
1835                f(input, depth)?;
1836            }
1837            HirRelationExpr::Filter { input, .. } => {
1838                f(input, depth)?;
1839            }
1840            HirRelationExpr::Join { left, right, .. } => {
1841                f(left, depth)?;
1842                f(right, depth + 1)?;
1843            }
1844            HirRelationExpr::Reduce { input, .. } => {
1845                f(input, depth)?;
1846            }
1847            HirRelationExpr::Distinct { input } => {
1848                f(input, depth)?;
1849            }
1850            HirRelationExpr::TopK { input, .. } => {
1851                f(input, depth)?;
1852            }
1853            HirRelationExpr::Negate { input } => {
1854                f(input, depth)?;
1855            }
1856            HirRelationExpr::Threshold { input } => {
1857                f(input, depth)?;
1858            }
1859            HirRelationExpr::Union { base, inputs } => {
1860                f(base, depth)?;
1861                for input in inputs {
1862                    f(input, depth)?;
1863                }
1864            }
1865        }
1866        Ok(())
1867    }
1868
1869    #[deprecated = "Use `Visit::visit_mut_post` instead."]
1870    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1871    where
1872        F: FnMut(&mut Self, usize),
1873    {
1874        #[allow(deprecated)]
1875        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1876                                                     depth: usize|
1877         -> Result<(), ()> {
1878            f(e, depth);
1879            Ok(())
1880        });
1881    }
1882
1883    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1884    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1885    where
1886        F: FnMut(&mut Self, usize) -> Result<(), E>,
1887    {
1888        #[allow(deprecated)]
1889        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1890            e.visit_mut_fallible(depth, f)
1891        })?;
1892        f(self, depth)
1893    }
1894
1895    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1896    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1897    where
1898        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1899    {
1900        match self {
1901            HirRelationExpr::Constant { .. }
1902            | HirRelationExpr::Get { .. }
1903            | HirRelationExpr::CallTable { .. } => (),
1904            HirRelationExpr::Let { body, value, .. } => {
1905                f(value, depth)?;
1906                f(body, depth)?;
1907            }
1908            HirRelationExpr::LetRec {
1909                limit: _,
1910                bindings,
1911                body,
1912            } => {
1913                for (_, _, value, _) in bindings.iter_mut() {
1914                    f(value, depth)?;
1915                }
1916                f(body, depth)?;
1917            }
1918            HirRelationExpr::Project { input, .. } => {
1919                f(input, depth)?;
1920            }
1921            HirRelationExpr::Map { input, .. } => {
1922                f(input, depth)?;
1923            }
1924            HirRelationExpr::Filter { input, .. } => {
1925                f(input, depth)?;
1926            }
1927            HirRelationExpr::Join { left, right, .. } => {
1928                f(left, depth)?;
1929                f(right, depth + 1)?;
1930            }
1931            HirRelationExpr::Reduce { input, .. } => {
1932                f(input, depth)?;
1933            }
1934            HirRelationExpr::Distinct { input } => {
1935                f(input, depth)?;
1936            }
1937            HirRelationExpr::TopK { input, .. } => {
1938                f(input, depth)?;
1939            }
1940            HirRelationExpr::Negate { input } => {
1941                f(input, depth)?;
1942            }
1943            HirRelationExpr::Threshold { input } => {
1944                f(input, depth)?;
1945            }
1946            HirRelationExpr::Union { base, inputs } => {
1947                f(base, depth)?;
1948                for input in inputs {
1949                    f(input, depth)?;
1950                }
1951            }
1952        }
1953        Ok(())
1954    }
1955
1956    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1957    /// Visits all scalar expressions within the sub-tree of the given relation.
1958    ///
1959    /// The `depth` argument should indicate the subquery nesting depth of the expression,
1960    /// which will be incremented when entering the RHS of a join or a subquery and
1961    /// presented to the supplied function `f`.
1962    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1963    where
1964        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1965    {
1966        #[allow(deprecated)]
1967        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1968                                         depth: usize|
1969         -> Result<(), E> {
1970            match e {
1971                HirRelationExpr::Join { on, .. } => {
1972                    f(on, depth)?;
1973                }
1974                HirRelationExpr::Map { scalars, .. } => {
1975                    for scalar in scalars {
1976                        f(scalar, depth)?;
1977                    }
1978                }
1979                HirRelationExpr::CallTable { exprs, .. } => {
1980                    for expr in exprs {
1981                        f(expr, depth)?;
1982                    }
1983                }
1984                HirRelationExpr::Filter { predicates, .. } => {
1985                    for predicate in predicates {
1986                        f(predicate, depth)?;
1987                    }
1988                }
1989                HirRelationExpr::Reduce { aggregates, .. } => {
1990                    for aggregate in aggregates {
1991                        f(&aggregate.expr, depth)?;
1992                    }
1993                }
1994                HirRelationExpr::TopK { limit, offset, .. } => {
1995                    if let Some(limit) = limit {
1996                        f(limit, depth)?;
1997                    }
1998                    f(offset, depth)?;
1999                }
2000                HirRelationExpr::Union { .. }
2001                | HirRelationExpr::Let { .. }
2002                | HirRelationExpr::LetRec { .. }
2003                | HirRelationExpr::Project { .. }
2004                | HirRelationExpr::Distinct { .. }
2005                | HirRelationExpr::Negate { .. }
2006                | HirRelationExpr::Threshold { .. }
2007                | HirRelationExpr::Constant { .. }
2008                | HirRelationExpr::Get { .. } => (),
2009            }
2010            Ok(())
2011        })
2012    }
2013
2014    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2015    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2016    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2017    where
2018        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2019    {
2020        #[allow(deprecated)]
2021        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2022                                             depth: usize|
2023         -> Result<(), E> {
2024            match e {
2025                HirRelationExpr::Join { on, .. } => {
2026                    f(on, depth)?;
2027                }
2028                HirRelationExpr::Map { scalars, .. } => {
2029                    for scalar in scalars.iter_mut() {
2030                        f(scalar, depth)?;
2031                    }
2032                }
2033                HirRelationExpr::CallTable { exprs, .. } => {
2034                    for expr in exprs.iter_mut() {
2035                        f(expr, depth)?;
2036                    }
2037                }
2038                HirRelationExpr::Filter { predicates, .. } => {
2039                    for predicate in predicates.iter_mut() {
2040                        f(predicate, depth)?;
2041                    }
2042                }
2043                HirRelationExpr::Reduce { aggregates, .. } => {
2044                    for aggregate in aggregates.iter_mut() {
2045                        f(&mut aggregate.expr, depth)?;
2046                    }
2047                }
2048                HirRelationExpr::TopK { limit, offset, .. } => {
2049                    if let Some(limit) = limit {
2050                        f(limit, depth)?;
2051                    }
2052                    f(offset, depth)?;
2053                }
2054                HirRelationExpr::Union { .. }
2055                | HirRelationExpr::Let { .. }
2056                | HirRelationExpr::LetRec { .. }
2057                | HirRelationExpr::Project { .. }
2058                | HirRelationExpr::Distinct { .. }
2059                | HirRelationExpr::Negate { .. }
2060                | HirRelationExpr::Threshold { .. }
2061                | HirRelationExpr::Constant { .. }
2062                | HirRelationExpr::Get { .. } => (),
2063            }
2064            Ok(())
2065        })
2066    }
2067
2068    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2069    /// Visits the column references in this relation expression.
2070    ///
2071    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2072    /// which will be incremented when entering the RHS of a join or a subquery and
2073    /// presented to the supplied function `f`.
2074    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2075    where
2076        F: FnMut(usize, &ColumnRef),
2077    {
2078        #[allow(deprecated)]
2079        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2080                                                           depth: usize|
2081         -> Result<(), ()> {
2082            e.visit_columns(depth, f);
2083            Ok(())
2084        });
2085    }
2086
2087    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2088    /// Like `visit_columns`, but permits mutating the column references.
2089    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2090    where
2091        F: FnMut(usize, &mut ColumnRef),
2092    {
2093        #[allow(deprecated)]
2094        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2095                                                               depth: usize|
2096         -> Result<(), ()> {
2097            e.visit_columns_mut(depth, f);
2098            Ok(())
2099        });
2100    }
2101
2102    /// Replaces any parameter references in the expression with the
2103    /// corresponding datum from `params`.
2104    pub fn bind_parameters(
2105        &mut self,
2106        scx: &StatementContext,
2107        lifetime: QueryLifetime,
2108        params: &Params,
2109    ) -> Result<(), PlanError> {
2110        #[allow(deprecated)]
2111        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2112            e.bind_parameters(scx, lifetime, params)
2113        })
2114    }
2115
2116    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2117        let mut contains_parameters = false;
2118        #[allow(deprecated)]
2119        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2120            if e.contains_parameters() {
2121                contains_parameters = true;
2122            }
2123            Ok::<(), PlanError>(())
2124        })?;
2125        Ok(contains_parameters)
2126    }
2127
2128    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2129    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2130        #[allow(deprecated)]
2131        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2132                                                               depth: usize|
2133         -> Result<(), ()> {
2134            e.splice_parameters(params, depth);
2135            Ok(())
2136        });
2137    }
2138
2139    /// Constructs a constant collection from specific rows and schema.
2140    pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
2141        let rows = rows
2142            .into_iter()
2143            .map(move |datums| Row::pack_slice(&datums))
2144            .collect();
2145        HirRelationExpr::Constant { rows, typ }
2146    }
2147
2148    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2149    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2150    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2151    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2152    /// finishing into a trivial finishing.
2153    pub fn finish_maintained(
2154        &mut self,
2155        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2156        group_size_hints: GroupSizeHints,
2157    ) {
2158        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2159            let old_finishing = mem::replace(
2160                finishing,
2161                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2162            );
2163            *self = HirRelationExpr::top_k(
2164                std::mem::replace(
2165                    self,
2166                    HirRelationExpr::Constant {
2167                        rows: vec![],
2168                        typ: RelationType::new(Vec::new()),
2169                    },
2170                ),
2171                vec![],
2172                old_finishing.order_by,
2173                old_finishing.limit,
2174                old_finishing.offset,
2175                group_size_hints.limit_input_group_size,
2176            )
2177            .project(old_finishing.project);
2178        }
2179    }
2180
2181    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2182    ///
2183    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2184    /// parameter is not an HirScalarExpr anymore.)
2185    pub fn trivial_row_set_finishing_hir(
2186        arity: usize,
2187    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2188        RowSetFinishing {
2189            order_by: Vec::new(),
2190            limit: None,
2191            offset: HirScalarExpr::literal(Datum::Int64(0), ScalarType::Int64),
2192            project: (0..arity).collect(),
2193        }
2194    }
2195
2196    /// True if the finishing does nothing to any result set.
2197    ///
2198    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2199    /// parameter is not an HirScalarExpr anymore.)
2200    pub fn is_trivial_row_set_finishing_hir(
2201        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2202        arity: usize,
2203    ) -> bool {
2204        rsf.limit.is_none()
2205            && rsf.order_by.is_empty()
2206            && rsf
2207                .offset
2208                .clone()
2209                .try_into_literal_int64()
2210                .is_ok_and(|o| o == 0)
2211            && rsf.project.iter().copied().eq(0..arity)
2212    }
2213
2214    /// The HirRelationExpr is considered potentially expensive if and only if
2215    /// at least one of the following conditions is true:
2216    ///
2217    ///  - It contains at least one CallTable or a Reduce operator.
2218    ///  - It contains at least one HirScalarExpr with a function call.
2219    ///
2220    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2221    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2222    pub fn could_run_expensive_function(&self) -> bool {
2223        let mut result = false;
2224        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2225            use HirRelationExpr::*;
2226            use HirScalarExpr::*;
2227
2228            self.visit_children(|scalar: &HirScalarExpr| {
2229                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2230                    result |= match scalar {
2231                        Column(..)
2232                        | Literal(..)
2233                        | CallUnmaterializable(..)
2234                        | If { .. }
2235                        | Parameter(..)
2236                        | Select(..)
2237                        | Exists(..) => false,
2238                        // Function calls are considered expensive
2239                        CallUnary { .. }
2240                        | CallBinary { .. }
2241                        | CallVariadic { .. }
2242                        | Windowing(..) => true,
2243                    };
2244                }) {
2245                    // Conservatively set `true` on RecursionLimitError.
2246                    result = true;
2247                }
2248            });
2249
2250            // CallTable has a table function; Reduce has an aggregate function.
2251            // Other constructs use MirScalarExpr to run a function
2252            result |= matches!(e, CallTable { .. } | Reduce { .. });
2253        }) {
2254            // Conservatively set `true` on RecursionLimitError.
2255            result = true;
2256        }
2257
2258        result
2259    }
2260
2261    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2262    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2263        let mut contains = false;
2264        self.visit_post(&mut |expr| {
2265            expr.visit_children(|expr: &HirScalarExpr| {
2266                contains = contains || expr.contains_temporal()
2267            })
2268        })?;
2269        Ok(contains)
2270    }
2271}
2272
2273impl CollectionPlan for HirRelationExpr {
2274    // !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2275    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2276    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2277        if let Self::Get {
2278            id: Id::Global(id), ..
2279        } = self
2280        {
2281            out.insert(*id);
2282        }
2283        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2284    }
2285}
2286
2287impl VisitChildren<Self> for HirRelationExpr {
2288    fn visit_children<F>(&self, mut f: F)
2289    where
2290        F: FnMut(&Self),
2291    {
2292        // subqueries of type HirRelationExpr might be wrapped in
2293        // Exists or Select variants within HirScalarExpr trees
2294        // attached at the current node, and we want to visit them as well
2295        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2296            #[allow(deprecated)]
2297            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2298                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2299                    f(expr.as_ref())
2300                }
2301                _ => (),
2302            });
2303        });
2304
2305        use HirRelationExpr::*;
2306        match self {
2307            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2308            Let {
2309                name: _,
2310                id: _,
2311                value,
2312                body,
2313            } => {
2314                f(value);
2315                f(body);
2316            }
2317            LetRec {
2318                limit: _,
2319                bindings,
2320                body,
2321            } => {
2322                for (_, _, value, _) in bindings.iter() {
2323                    f(value);
2324                }
2325                f(body);
2326            }
2327            Project { input, outputs: _ } => f(input),
2328            Map { input, scalars: _ } => {
2329                f(input);
2330            }
2331            CallTable { func: _, exprs: _ } => (),
2332            Filter {
2333                input,
2334                predicates: _,
2335            } => {
2336                f(input);
2337            }
2338            Join {
2339                left,
2340                right,
2341                on: _,
2342                kind: _,
2343            } => {
2344                f(left);
2345                f(right);
2346            }
2347            Reduce {
2348                input,
2349                group_key: _,
2350                aggregates: _,
2351                expected_group_size: _,
2352            } => {
2353                f(input);
2354            }
2355            Distinct { input }
2356            | TopK {
2357                input,
2358                group_key: _,
2359                order_key: _,
2360                limit: _,
2361                offset: _,
2362                expected_group_size: _,
2363            }
2364            | Negate { input }
2365            | Threshold { input } => {
2366                f(input);
2367            }
2368            Union { base, inputs } => {
2369                f(base);
2370                for input in inputs {
2371                    f(input);
2372                }
2373            }
2374        }
2375    }
2376
2377    fn visit_mut_children<F>(&mut self, mut f: F)
2378    where
2379        F: FnMut(&mut Self),
2380    {
2381        // subqueries of type HirRelationExpr might be wrapped in
2382        // Exists or Select variants within HirScalarExpr trees
2383        // attached at the current node, and we want to visit them as well
2384        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2385            #[allow(deprecated)]
2386            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2387                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2388                    f(expr.as_mut())
2389                }
2390                _ => (),
2391            });
2392        });
2393
2394        use HirRelationExpr::*;
2395        match self {
2396            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2397            Let {
2398                name: _,
2399                id: _,
2400                value,
2401                body,
2402            } => {
2403                f(value);
2404                f(body);
2405            }
2406            LetRec {
2407                limit: _,
2408                bindings,
2409                body,
2410            } => {
2411                for (_, _, value, _) in bindings.iter_mut() {
2412                    f(value);
2413                }
2414                f(body);
2415            }
2416            Project { input, outputs: _ } => f(input),
2417            Map { input, scalars: _ } => {
2418                f(input);
2419            }
2420            CallTable { func: _, exprs: _ } => (),
2421            Filter {
2422                input,
2423                predicates: _,
2424            } => {
2425                f(input);
2426            }
2427            Join {
2428                left,
2429                right,
2430                on: _,
2431                kind: _,
2432            } => {
2433                f(left);
2434                f(right);
2435            }
2436            Reduce {
2437                input,
2438                group_key: _,
2439                aggregates: _,
2440                expected_group_size: _,
2441            } => {
2442                f(input);
2443            }
2444            Distinct { input }
2445            | TopK {
2446                input,
2447                group_key: _,
2448                order_key: _,
2449                limit: _,
2450                offset: _,
2451                expected_group_size: _,
2452            }
2453            | Negate { input }
2454            | Threshold { input } => {
2455                f(input);
2456            }
2457            Union { base, inputs } => {
2458                f(base);
2459                for input in inputs {
2460                    f(input);
2461                }
2462            }
2463        }
2464    }
2465
2466    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2467    where
2468        F: FnMut(&Self) -> Result<(), E>,
2469        E: From<RecursionLimitError>,
2470    {
2471        // subqueries of type HirRelationExpr might be wrapped in
2472        // Exists or Select variants within HirScalarExpr trees
2473        // attached at the current node, and we want to visit them as well
2474        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2475            Visit::try_visit_post(expr, &mut |expr| match expr {
2476                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2477                    f(expr.as_ref())
2478                }
2479                _ => Ok(()),
2480            })
2481        })?;
2482
2483        use HirRelationExpr::*;
2484        match self {
2485            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2486            Let {
2487                name: _,
2488                id: _,
2489                value,
2490                body,
2491            } => {
2492                f(value)?;
2493                f(body)?;
2494            }
2495            LetRec {
2496                limit: _,
2497                bindings,
2498                body,
2499            } => {
2500                for (_, _, value, _) in bindings.iter() {
2501                    f(value)?;
2502                }
2503                f(body)?;
2504            }
2505            Project { input, outputs: _ } => f(input)?,
2506            Map { input, scalars: _ } => {
2507                f(input)?;
2508            }
2509            CallTable { func: _, exprs: _ } => (),
2510            Filter {
2511                input,
2512                predicates: _,
2513            } => {
2514                f(input)?;
2515            }
2516            Join {
2517                left,
2518                right,
2519                on: _,
2520                kind: _,
2521            } => {
2522                f(left)?;
2523                f(right)?;
2524            }
2525            Reduce {
2526                input,
2527                group_key: _,
2528                aggregates: _,
2529                expected_group_size: _,
2530            } => {
2531                f(input)?;
2532            }
2533            Distinct { input }
2534            | TopK {
2535                input,
2536                group_key: _,
2537                order_key: _,
2538                limit: _,
2539                offset: _,
2540                expected_group_size: _,
2541            }
2542            | Negate { input }
2543            | Threshold { input } => {
2544                f(input)?;
2545            }
2546            Union { base, inputs } => {
2547                f(base)?;
2548                for input in inputs {
2549                    f(input)?;
2550                }
2551            }
2552        }
2553        Ok(())
2554    }
2555
2556    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2557    where
2558        F: FnMut(&mut Self) -> Result<(), E>,
2559        E: From<RecursionLimitError>,
2560    {
2561        // subqueries of type HirRelationExpr might be wrapped in
2562        // Exists or Select variants within HirScalarExpr trees
2563        // attached at the current node, and we want to visit them as well
2564        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2565            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2566                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2567                    f(expr.as_mut())
2568                }
2569                _ => Ok(()),
2570            })
2571        })?;
2572
2573        use HirRelationExpr::*;
2574        match self {
2575            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2576            Let {
2577                name: _,
2578                id: _,
2579                value,
2580                body,
2581            } => {
2582                f(value)?;
2583                f(body)?;
2584            }
2585            LetRec {
2586                limit: _,
2587                bindings,
2588                body,
2589            } => {
2590                for (_, _, value, _) in bindings.iter_mut() {
2591                    f(value)?;
2592                }
2593                f(body)?;
2594            }
2595            Project { input, outputs: _ } => f(input)?,
2596            Map { input, scalars: _ } => {
2597                f(input)?;
2598            }
2599            CallTable { func: _, exprs: _ } => (),
2600            Filter {
2601                input,
2602                predicates: _,
2603            } => {
2604                f(input)?;
2605            }
2606            Join {
2607                left,
2608                right,
2609                on: _,
2610                kind: _,
2611            } => {
2612                f(left)?;
2613                f(right)?;
2614            }
2615            Reduce {
2616                input,
2617                group_key: _,
2618                aggregates: _,
2619                expected_group_size: _,
2620            } => {
2621                f(input)?;
2622            }
2623            Distinct { input }
2624            | TopK {
2625                input,
2626                group_key: _,
2627                order_key: _,
2628                limit: _,
2629                offset: _,
2630                expected_group_size: _,
2631            }
2632            | Negate { input }
2633            | Threshold { input } => {
2634                f(input)?;
2635            }
2636            Union { base, inputs } => {
2637                f(base)?;
2638                for input in inputs {
2639                    f(input)?;
2640                }
2641            }
2642        }
2643        Ok(())
2644    }
2645}
2646
2647impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2648    fn visit_children<F>(&self, mut f: F)
2649    where
2650        F: FnMut(&HirScalarExpr),
2651    {
2652        use HirRelationExpr::*;
2653        match self {
2654            Constant { rows: _, typ: _ }
2655            | Get { id: _, typ: _ }
2656            | Let {
2657                name: _,
2658                id: _,
2659                value: _,
2660                body: _,
2661            }
2662            | LetRec {
2663                limit: _,
2664                bindings: _,
2665                body: _,
2666            }
2667            | Project {
2668                input: _,
2669                outputs: _,
2670            } => (),
2671            Map { input: _, scalars } => {
2672                for scalar in scalars {
2673                    f(scalar);
2674                }
2675            }
2676            CallTable { func: _, exprs } => {
2677                for expr in exprs {
2678                    f(expr);
2679                }
2680            }
2681            Filter {
2682                input: _,
2683                predicates,
2684            } => {
2685                for predicate in predicates {
2686                    f(predicate);
2687                }
2688            }
2689            Join {
2690                left: _,
2691                right: _,
2692                on,
2693                kind: _,
2694            } => f(on),
2695            Reduce {
2696                input: _,
2697                group_key: _,
2698                aggregates,
2699                expected_group_size: _,
2700            } => {
2701                for aggregate in aggregates {
2702                    f(aggregate.expr.as_ref());
2703                }
2704            }
2705            TopK {
2706                input: _,
2707                group_key: _,
2708                order_key: _,
2709                limit,
2710                offset,
2711                expected_group_size: _,
2712            } => {
2713                if let Some(limit) = limit {
2714                    f(limit)
2715                }
2716                f(offset)
2717            }
2718            Distinct { input: _ }
2719            | Negate { input: _ }
2720            | Threshold { input: _ }
2721            | Union { base: _, inputs: _ } => (),
2722        }
2723    }
2724
2725    fn visit_mut_children<F>(&mut self, mut f: F)
2726    where
2727        F: FnMut(&mut HirScalarExpr),
2728    {
2729        use HirRelationExpr::*;
2730        match self {
2731            Constant { rows: _, typ: _ }
2732            | Get { id: _, typ: _ }
2733            | Let {
2734                name: _,
2735                id: _,
2736                value: _,
2737                body: _,
2738            }
2739            | LetRec {
2740                limit: _,
2741                bindings: _,
2742                body: _,
2743            }
2744            | Project {
2745                input: _,
2746                outputs: _,
2747            } => (),
2748            Map { input: _, scalars } => {
2749                for scalar in scalars {
2750                    f(scalar);
2751                }
2752            }
2753            CallTable { func: _, exprs } => {
2754                for expr in exprs {
2755                    f(expr);
2756                }
2757            }
2758            Filter {
2759                input: _,
2760                predicates,
2761            } => {
2762                for predicate in predicates {
2763                    f(predicate);
2764                }
2765            }
2766            Join {
2767                left: _,
2768                right: _,
2769                on,
2770                kind: _,
2771            } => f(on),
2772            Reduce {
2773                input: _,
2774                group_key: _,
2775                aggregates,
2776                expected_group_size: _,
2777            } => {
2778                for aggregate in aggregates {
2779                    f(aggregate.expr.as_mut());
2780                }
2781            }
2782            TopK {
2783                input: _,
2784                group_key: _,
2785                order_key: _,
2786                limit,
2787                offset,
2788                expected_group_size: _,
2789            } => {
2790                if let Some(limit) = limit {
2791                    f(limit)
2792                }
2793                f(offset)
2794            }
2795            Distinct { input: _ }
2796            | Negate { input: _ }
2797            | Threshold { input: _ }
2798            | Union { base: _, inputs: _ } => (),
2799        }
2800    }
2801
2802    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2803    where
2804        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2805        E: From<RecursionLimitError>,
2806    {
2807        use HirRelationExpr::*;
2808        match self {
2809            Constant { rows: _, typ: _ }
2810            | Get { id: _, typ: _ }
2811            | Let {
2812                name: _,
2813                id: _,
2814                value: _,
2815                body: _,
2816            }
2817            | LetRec {
2818                limit: _,
2819                bindings: _,
2820                body: _,
2821            }
2822            | Project {
2823                input: _,
2824                outputs: _,
2825            } => (),
2826            Map { input: _, scalars } => {
2827                for scalar in scalars {
2828                    f(scalar)?;
2829                }
2830            }
2831            CallTable { func: _, exprs } => {
2832                for expr in exprs {
2833                    f(expr)?;
2834                }
2835            }
2836            Filter {
2837                input: _,
2838                predicates,
2839            } => {
2840                for predicate in predicates {
2841                    f(predicate)?;
2842                }
2843            }
2844            Join {
2845                left: _,
2846                right: _,
2847                on,
2848                kind: _,
2849            } => f(on)?,
2850            Reduce {
2851                input: _,
2852                group_key: _,
2853                aggregates,
2854                expected_group_size: _,
2855            } => {
2856                for aggregate in aggregates {
2857                    f(aggregate.expr.as_ref())?;
2858                }
2859            }
2860            TopK {
2861                input: _,
2862                group_key: _,
2863                order_key: _,
2864                limit,
2865                offset,
2866                expected_group_size: _,
2867            } => {
2868                if let Some(limit) = limit {
2869                    f(limit)?
2870                }
2871                f(offset)?
2872            }
2873            Distinct { input: _ }
2874            | Negate { input: _ }
2875            | Threshold { input: _ }
2876            | Union { base: _, inputs: _ } => (),
2877        }
2878        Ok(())
2879    }
2880
2881    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2882    where
2883        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2884        E: From<RecursionLimitError>,
2885    {
2886        use HirRelationExpr::*;
2887        match self {
2888            Constant { rows: _, typ: _ }
2889            | Get { id: _, typ: _ }
2890            | Let {
2891                name: _,
2892                id: _,
2893                value: _,
2894                body: _,
2895            }
2896            | LetRec {
2897                limit: _,
2898                bindings: _,
2899                body: _,
2900            }
2901            | Project {
2902                input: _,
2903                outputs: _,
2904            } => (),
2905            Map { input: _, scalars } => {
2906                for scalar in scalars {
2907                    f(scalar)?;
2908                }
2909            }
2910            CallTable { func: _, exprs } => {
2911                for expr in exprs {
2912                    f(expr)?;
2913                }
2914            }
2915            Filter {
2916                input: _,
2917                predicates,
2918            } => {
2919                for predicate in predicates {
2920                    f(predicate)?;
2921                }
2922            }
2923            Join {
2924                left: _,
2925                right: _,
2926                on,
2927                kind: _,
2928            } => f(on)?,
2929            Reduce {
2930                input: _,
2931                group_key: _,
2932                aggregates,
2933                expected_group_size: _,
2934            } => {
2935                for aggregate in aggregates {
2936                    f(aggregate.expr.as_mut())?;
2937                }
2938            }
2939            TopK {
2940                input: _,
2941                group_key: _,
2942                order_key: _,
2943                limit,
2944                offset,
2945                expected_group_size: _,
2946            } => {
2947                if let Some(limit) = limit {
2948                    f(limit)?
2949                }
2950                f(offset)?
2951            }
2952            Distinct { input: _ }
2953            | Negate { input: _ }
2954            | Threshold { input: _ }
2955            | Union { base: _, inputs: _ } => (),
2956        }
2957        Ok(())
2958    }
2959}
2960
2961impl HirScalarExpr {
2962    pub fn name(&self) -> Option<Arc<str>> {
2963        use HirScalarExpr::*;
2964        match self {
2965            Column(_, name)
2966            | Parameter(_, name)
2967            | Literal(_, _, name)
2968            | CallUnmaterializable(_, name)
2969            | CallUnary { name, .. }
2970            | CallBinary { name, .. }
2971            | CallVariadic { name, .. }
2972            | If { name, .. }
2973            | Exists(_, name)
2974            | Select(_, name)
2975            | Windowing(_, name) => name.0.clone(),
2976        }
2977    }
2978
2979    /// Replaces any parameter references in the expression with the
2980    /// corresponding datum in `params`.
2981    pub fn bind_parameters(
2982        &mut self,
2983        scx: &StatementContext,
2984        lifetime: QueryLifetime,
2985        params: &Params,
2986    ) -> Result<(), PlanError> {
2987        #[allow(deprecated)]
2988        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
2989            if let HirScalarExpr::Parameter(n, name) = e {
2990                let datum = match params.datums.iter().nth(*n - 1) {
2991                    None => return Err(PlanError::UnknownParameter(*n)),
2992                    Some(datum) => datum,
2993                };
2994                let scalar_type = &params.execute_types[*n - 1];
2995                let row = Row::pack([datum]);
2996                let column_type = scalar_type.clone().nullable(datum.is_null());
2997
2998                let name = if let Some(name) = &name.0 {
2999                    Some(Arc::clone(name))
3000                } else {
3001                    Some(Arc::from(format!("${n}")))
3002                };
3003
3004                let qcx = QueryContext::root(scx, lifetime);
3005                let ecx = execute_expr_context(&qcx);
3006
3007                *e = plan_cast(
3008                    &ecx,
3009                    *EXECUTE_CAST_CONTEXT,
3010                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3011                    &params.expected_types[*n - 1],
3012                )
3013                .expect("checked in plan_params");
3014            }
3015            Ok(())
3016        })
3017    }
3018
3019    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3020    /// replaced with the corresponding expression fragment from `params` rather
3021    /// than a datum.
3022    ///
3023    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3024    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3025    /// in `self` that refer to invalid indices of `params` will cause a panic.
3026    ///
3027    /// Column references in parameters will be corrected to account for the
3028    /// depth at which they are spliced.
3029    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3030        #[allow(deprecated)]
3031        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3032                                                        e: &mut HirScalarExpr|
3033         -> Result<(), ()> {
3034            if let HirScalarExpr::Parameter(i, _name) = e {
3035                *e = params[*i - 1].clone();
3036                // Correct any column references in the parameter expression for
3037                // its new depth.
3038                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3039                    if col.level >= d {
3040                        col.level += depth
3041                    }
3042                });
3043            }
3044            Ok(())
3045        });
3046    }
3047
3048    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3049    pub fn contains_temporal(&self) -> bool {
3050        let mut contains = false;
3051        #[allow(deprecated)]
3052        self.visit_post_nolimit(&mut |e| {
3053            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3054                contains = true;
3055            }
3056        });
3057        contains
3058    }
3059
3060    /// Constructs an unnamed column reference in the current scope.
3061    /// Use [`HirScalarExpr::named_column`] when a name is known.
3062    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3063    pub fn column(index: usize) -> HirScalarExpr {
3064        HirScalarExpr::Column(
3065            ColumnRef {
3066                level: 0,
3067                column: index,
3068            },
3069            TreatAsEqual(None),
3070        )
3071    }
3072
3073    /// Constructs an unnamed column reference.
3074    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3075        HirScalarExpr::Column(cr, TreatAsEqual(None))
3076    }
3077
3078    /// Constructs a named column reference.
3079    /// Names are interned by a `NameManager`.
3080    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3081        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3082    }
3083
3084    pub fn parameter(n: usize) -> HirScalarExpr {
3085        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3086    }
3087
3088    pub fn literal(datum: Datum, scalar_type: ScalarType) -> HirScalarExpr {
3089        let col_type = scalar_type.nullable(datum.is_null());
3090        soft_assert_or_log!(datum.is_instance_of(&col_type), "type is correct");
3091        let row = Row::pack([datum]);
3092        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3093    }
3094
3095    pub fn literal_true() -> HirScalarExpr {
3096        HirScalarExpr::literal(Datum::True, ScalarType::Bool)
3097    }
3098
3099    pub fn literal_false() -> HirScalarExpr {
3100        HirScalarExpr::literal(Datum::False, ScalarType::Bool)
3101    }
3102
3103    pub fn literal_null(scalar_type: ScalarType) -> HirScalarExpr {
3104        HirScalarExpr::literal(Datum::Null, scalar_type)
3105    }
3106
3107    pub fn literal_1d_array(
3108        datums: Vec<Datum>,
3109        element_scalar_type: ScalarType,
3110    ) -> Result<HirScalarExpr, PlanError> {
3111        let scalar_type = match element_scalar_type {
3112            ScalarType::Array(_) => {
3113                sql_bail!("cannot build array from array type");
3114            }
3115            typ => ScalarType::Array(Box::new(typ)).nullable(false),
3116        };
3117
3118        let mut row = Row::default();
3119        row.packer()
3120            .try_push_array(
3121                &[ArrayDimension {
3122                    lower_bound: 1,
3123                    length: datums.len(),
3124                }],
3125                datums,
3126            )
3127            .expect("array constructed to be valid");
3128
3129        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3130    }
3131
3132    pub fn as_literal(&self) -> Option<Datum<'_>> {
3133        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3134            Some(row.unpack_first())
3135        } else {
3136            None
3137        }
3138    }
3139
3140    pub fn is_literal_true(&self) -> bool {
3141        Some(Datum::True) == self.as_literal()
3142    }
3143
3144    pub fn is_literal_false(&self) -> bool {
3145        Some(Datum::False) == self.as_literal()
3146    }
3147
3148    pub fn is_literal_null(&self) -> bool {
3149        Some(Datum::Null) == self.as_literal()
3150    }
3151
3152    /// Return true iff `self` consists only of literals, materializable function calls, and
3153    /// if-else statements.
3154    pub fn is_constant(&self) -> bool {
3155        let mut worklist = vec![self];
3156        while let Some(expr) = worklist.pop() {
3157            match expr {
3158                Self::Literal(..) => {
3159                    // leaf node, do nothing
3160                }
3161                Self::CallUnary { expr, .. } => {
3162                    worklist.push(expr);
3163                }
3164                Self::CallBinary {
3165                    func: _,
3166                    expr1,
3167                    expr2,
3168                    name: _,
3169                } => {
3170                    worklist.push(expr1);
3171                    worklist.push(expr2);
3172                }
3173                Self::CallVariadic {
3174                    func: _,
3175                    exprs,
3176                    name: _,
3177                } => {
3178                    worklist.extend(exprs.iter());
3179                }
3180                // (CallUnmaterializable is not allowed)
3181                Self::If {
3182                    cond,
3183                    then,
3184                    els,
3185                    name: _,
3186                } => {
3187                    worklist.push(cond);
3188                    worklist.push(then);
3189                    worklist.push(els);
3190                }
3191                _ => {
3192                    return false; // Any other node makes `self` non-constant.
3193                }
3194            }
3195        }
3196        true
3197    }
3198
3199    pub fn call_unary(self, func: UnaryFunc) -> Self {
3200        HirScalarExpr::CallUnary {
3201            func,
3202            expr: Box::new(self),
3203            name: NameMetadata::default(),
3204        }
3205    }
3206
3207    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
3208        HirScalarExpr::CallBinary {
3209            func,
3210            expr1: Box::new(self),
3211            expr2: Box::new(other),
3212            name: NameMetadata::default(),
3213        }
3214    }
3215
3216    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3217        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3218    }
3219
3220    pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3221        HirScalarExpr::CallVariadic {
3222            func,
3223            exprs,
3224            name: NameMetadata::default(),
3225        }
3226    }
3227
3228    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3229        HirScalarExpr::If {
3230            cond: Box::new(cond),
3231            then: Box::new(then),
3232            els: Box::new(els),
3233            name: NameMetadata::default(),
3234        }
3235    }
3236
3237    pub fn windowing(expr: WindowExpr) -> Self {
3238        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3239    }
3240
3241    pub fn or(self, other: Self) -> Self {
3242        HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3243    }
3244
3245    pub fn and(self, other: Self) -> Self {
3246        HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3247    }
3248
3249    pub fn not(self) -> Self {
3250        self.call_unary(UnaryFunc::Not(func::Not))
3251    }
3252
3253    pub fn call_is_null(self) -> Self {
3254        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3255    }
3256
3257    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3258    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3259        match args.len() {
3260            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3261            1 => args.swap_remove(0),
3262            _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3263        }
3264    }
3265
3266    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3267    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3268        match args.len() {
3269            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3270            1 => args.swap_remove(0),
3271            _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3272        }
3273    }
3274
3275    pub fn take(&mut self) -> Self {
3276        mem::replace(self, HirScalarExpr::literal_null(ScalarType::String))
3277    }
3278
3279    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3280    /// Visits the column references in this scalar expression.
3281    ///
3282    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3283    /// which will be incremented with each subquery entered and presented to the supplied
3284    /// function `f`.
3285    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3286    where
3287        F: FnMut(usize, &ColumnRef),
3288    {
3289        #[allow(deprecated)]
3290        let _ = self.visit_recursively(depth, &mut |depth: usize,
3291                                                    e: &HirScalarExpr|
3292         -> Result<(), ()> {
3293            if let HirScalarExpr::Column(col, _name) = e {
3294                f(depth, col)
3295            }
3296            Ok(())
3297        });
3298    }
3299
3300    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3301    /// Like `visit_columns`, but permits mutating the column references.
3302    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3303    where
3304        F: FnMut(usize, &mut ColumnRef),
3305    {
3306        #[allow(deprecated)]
3307        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3308                                                        e: &mut HirScalarExpr|
3309         -> Result<(), ()> {
3310            if let HirScalarExpr::Column(col, _name) = e {
3311                f(depth, col)
3312            }
3313            Ok(())
3314        });
3315    }
3316
3317    /// Visits those column references in this scalar expression that refer to the root
3318    /// level. These include column references that are at the root level, as well as column
3319    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3320    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3321    /// "root level" to be `self`'s level.)
3322    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3323    where
3324        F: FnMut(usize),
3325    {
3326        #[allow(deprecated)]
3327        let _ = self.visit_recursively(0, &mut |depth: usize,
3328                                                e: &HirScalarExpr|
3329         -> Result<(), ()> {
3330            if let HirScalarExpr::Column(col, _name) = e {
3331                if col.level == depth {
3332                    f(col.column)
3333                }
3334            }
3335            Ok(())
3336        });
3337    }
3338
3339    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3340    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3341    where
3342        F: FnMut(&mut usize),
3343    {
3344        #[allow(deprecated)]
3345        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3346                                                    e: &mut HirScalarExpr|
3347         -> Result<(), ()> {
3348            if let HirScalarExpr::Column(col, _name) = e {
3349                if col.level == depth {
3350                    f(&mut col.column)
3351                }
3352            }
3353            Ok(())
3354        });
3355    }
3356
3357    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3358    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3359    /// in them. It takes the current depth of the expression and increases it when
3360    /// entering a subquery.
3361    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3362    where
3363        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3364    {
3365        match self {
3366            HirScalarExpr::Literal(..)
3367            | HirScalarExpr::Parameter(..)
3368            | HirScalarExpr::CallUnmaterializable(..)
3369            | HirScalarExpr::Column(..) => (),
3370            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3371            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3372                expr1.visit_recursively(depth, f)?;
3373                expr2.visit_recursively(depth, f)?;
3374            }
3375            HirScalarExpr::CallVariadic { exprs, .. } => {
3376                for expr in exprs {
3377                    expr.visit_recursively(depth, f)?;
3378                }
3379            }
3380            HirScalarExpr::If {
3381                cond,
3382                then,
3383                els,
3384                name: _,
3385            } => {
3386                cond.visit_recursively(depth, f)?;
3387                then.visit_recursively(depth, f)?;
3388                els.visit_recursively(depth, f)?;
3389            }
3390            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3391                #[allow(deprecated)]
3392                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3393                    e.visit_recursively(depth, f)
3394                })?;
3395            }
3396            HirScalarExpr::Windowing(expr, _name) => {
3397                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3398            }
3399        }
3400        f(depth, self)
3401    }
3402
3403    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3404    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3405    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3406    where
3407        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3408    {
3409        match self {
3410            HirScalarExpr::Literal(..)
3411            | HirScalarExpr::Parameter(..)
3412            | HirScalarExpr::CallUnmaterializable(..)
3413            | HirScalarExpr::Column(..) => (),
3414            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3415            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3416                expr1.visit_recursively_mut(depth, f)?;
3417                expr2.visit_recursively_mut(depth, f)?;
3418            }
3419            HirScalarExpr::CallVariadic { exprs, .. } => {
3420                for expr in exprs {
3421                    expr.visit_recursively_mut(depth, f)?;
3422                }
3423            }
3424            HirScalarExpr::If {
3425                cond,
3426                then,
3427                els,
3428                name: _,
3429            } => {
3430                cond.visit_recursively_mut(depth, f)?;
3431                then.visit_recursively_mut(depth, f)?;
3432                els.visit_recursively_mut(depth, f)?;
3433            }
3434            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3435                #[allow(deprecated)]
3436                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3437                    e.visit_recursively_mut(depth, f)
3438                })?;
3439            }
3440            HirScalarExpr::Windowing(expr, _name) => {
3441                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3442            }
3443        }
3444        f(depth, self)
3445    }
3446
3447    /// Attempts to simplify self into a literal.
3448    ///
3449    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3450    /// an evaluation error occurs during simplification, or if self contains
3451    /// - a subquery
3452    /// - a column reference to an outer level
3453    /// - a parameter
3454    /// - a window function call
3455    fn simplify_to_literal(self) -> Option<Row> {
3456        let mut expr = self.lower_uncorrelated().ok()?;
3457        expr.reduce(&[]);
3458        match expr {
3459            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3460            _ => None,
3461        }
3462    }
3463
3464    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3465    /// or an evaluation error occurs during simplification), it returns
3466    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3467    ///
3468    /// The returned error is an _internal_ error if the expression contains
3469    /// - a subquery
3470    /// - a column reference to an outer level
3471    /// - a parameter
3472    /// - a window function call
3473    ///
3474    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3475    /// msg.
3476    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3477        let mut expr = self.lower_uncorrelated().map_err(|err| {
3478            PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3479        })?;
3480        expr.reduce(&[]);
3481        match expr {
3482            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3483            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3484                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3485            ),
3486            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3487                "Not a constant".to_string(),
3488            )),
3489        }
3490    }
3491
3492    /// Attempts to simplify this expression to a literal 64-bit integer.
3493    ///
3494    /// Returns `None` if this expression cannot be simplified, e.g. because it
3495    /// contains non-literal values.
3496    ///
3497    /// # Panics
3498    ///
3499    /// Panics if this expression does not have type [`ScalarType::Int64`].
3500    pub fn into_literal_int64(self) -> Option<i64> {
3501        self.simplify_to_literal().and_then(|row| {
3502            let datum = row.unpack_first();
3503            if datum.is_null() {
3504                None
3505            } else {
3506                Some(datum.unwrap_int64())
3507            }
3508        })
3509    }
3510
3511    /// Attempts to simplify this expression to a literal string.
3512    ///
3513    /// Returns `None` if this expression cannot be simplified, e.g. because it
3514    /// contains non-literal values.
3515    ///
3516    /// # Panics
3517    ///
3518    /// Panics if this expression does not have type [`ScalarType::String`].
3519    pub fn into_literal_string(self) -> Option<String> {
3520        self.simplify_to_literal().and_then(|row| {
3521            let datum = row.unpack_first();
3522            if datum.is_null() {
3523                None
3524            } else {
3525                Some(datum.unwrap_str().to_owned())
3526            }
3527        })
3528    }
3529
3530    /// Attempts to simplify this expression to a literal MzTimestamp.
3531    ///
3532    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3533    /// simplified, e.g. because it contains non-literal values or a cast fails.
3534    ///
3535    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3536    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3537    ///
3538    /// # Panics
3539    ///
3540    /// Panics if this expression does not have type [`ScalarType::MzTimestamp`].
3541    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3542        self.simplify_to_literal().and_then(|row| {
3543            let datum = row.unpack_first();
3544            if datum.is_null() {
3545                None
3546            } else {
3547                Some(datum.unwrap_mz_timestamp())
3548            }
3549        })
3550    }
3551
3552    /// Attempts to simplify this expression of [`ScalarType::Int64`] to a literal Int64 and
3553    /// returns it as an i64.
3554    ///
3555    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3556    /// - it's not a constant expression (as determined by `is_constant`)
3557    /// - evaluates to null
3558    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3559    ///
3560    /// # Panics
3561    ///
3562    /// Panics if this expression does not have type [`ScalarType::Int64`].
3563    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3564        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3565        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3566        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3567        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3568        // preconditions also of all the other into_literal_... functions.)
3569        if !self.is_constant() {
3570            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3571                "Expected a constant expression, got {}",
3572                self
3573            )));
3574        }
3575        self.clone()
3576            .simplify_to_literal_with_result()
3577            .and_then(|row| {
3578                let datum = row.unpack_first();
3579                if datum.is_null() {
3580                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3581                        "Expected an expression that evaluates to a non-null value, got {}",
3582                        self
3583                    )))
3584                } else {
3585                    Ok(datum.unwrap_int64())
3586                }
3587            })
3588    }
3589
3590    pub fn contains_parameters(&self) -> bool {
3591        let mut contains_parameters = false;
3592        #[allow(deprecated)]
3593        let _ = self.visit_recursively(0, &mut |_depth: usize,
3594                                                expr: &HirScalarExpr|
3595         -> Result<(), ()> {
3596            if let HirScalarExpr::Parameter(..) = expr {
3597                contains_parameters = true;
3598            }
3599            Ok(())
3600        });
3601        contains_parameters
3602    }
3603}
3604
3605impl VisitChildren<Self> for HirScalarExpr {
3606    fn visit_children<F>(&self, mut f: F)
3607    where
3608        F: FnMut(&Self),
3609    {
3610        use HirScalarExpr::*;
3611        match self {
3612            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3613            CallUnary { expr, .. } => f(expr),
3614            CallBinary { expr1, expr2, .. } => {
3615                f(expr1);
3616                f(expr2);
3617            }
3618            CallVariadic { exprs, .. } => {
3619                for expr in exprs {
3620                    f(expr);
3621                }
3622            }
3623            If {
3624                cond,
3625                then,
3626                els,
3627                name: _,
3628            } => {
3629                f(cond);
3630                f(then);
3631                f(els);
3632            }
3633            Exists(..) | Select(..) => (),
3634            Windowing(expr, _name) => expr.visit_children(f),
3635        }
3636    }
3637
3638    fn visit_mut_children<F>(&mut self, mut f: F)
3639    where
3640        F: FnMut(&mut Self),
3641    {
3642        use HirScalarExpr::*;
3643        match self {
3644            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3645            CallUnary { expr, .. } => f(expr),
3646            CallBinary { expr1, expr2, .. } => {
3647                f(expr1);
3648                f(expr2);
3649            }
3650            CallVariadic { exprs, .. } => {
3651                for expr in exprs {
3652                    f(expr);
3653                }
3654            }
3655            If {
3656                cond,
3657                then,
3658                els,
3659                name: _,
3660            } => {
3661                f(cond);
3662                f(then);
3663                f(els);
3664            }
3665            Exists(..) | Select(..) => (),
3666            Windowing(expr, _name) => expr.visit_mut_children(f),
3667        }
3668    }
3669
3670    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3671    where
3672        F: FnMut(&Self) -> Result<(), E>,
3673        E: From<RecursionLimitError>,
3674    {
3675        use HirScalarExpr::*;
3676        match self {
3677            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3678            CallUnary { expr, .. } => f(expr)?,
3679            CallBinary { expr1, expr2, .. } => {
3680                f(expr1)?;
3681                f(expr2)?;
3682            }
3683            CallVariadic { exprs, .. } => {
3684                for expr in exprs {
3685                    f(expr)?;
3686                }
3687            }
3688            If {
3689                cond,
3690                then,
3691                els,
3692                name: _,
3693            } => {
3694                f(cond)?;
3695                f(then)?;
3696                f(els)?;
3697            }
3698            Exists(..) | Select(..) => (),
3699            Windowing(expr, _name) => expr.try_visit_children(f)?,
3700        }
3701        Ok(())
3702    }
3703
3704    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3705    where
3706        F: FnMut(&mut Self) -> Result<(), E>,
3707        E: From<RecursionLimitError>,
3708    {
3709        use HirScalarExpr::*;
3710        match self {
3711            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3712            CallUnary { expr, .. } => f(expr)?,
3713            CallBinary { expr1, expr2, .. } => {
3714                f(expr1)?;
3715                f(expr2)?;
3716            }
3717            CallVariadic { exprs, .. } => {
3718                for expr in exprs {
3719                    f(expr)?;
3720                }
3721            }
3722            If {
3723                cond,
3724                then,
3725                els,
3726                name: _,
3727            } => {
3728                f(cond)?;
3729                f(then)?;
3730                f(els)?;
3731            }
3732            Exists(..) | Select(..) => (),
3733            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3734        }
3735        Ok(())
3736    }
3737}
3738
3739impl AbstractExpr for HirScalarExpr {
3740    type Type = ColumnType;
3741
3742    fn typ(
3743        &self,
3744        outers: &[RelationType],
3745        inner: &RelationType,
3746        params: &BTreeMap<usize, ScalarType>,
3747    ) -> Self::Type {
3748        stack::maybe_grow(|| match self {
3749            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3750                if *level == 0 {
3751                    inner.column_types[*column].clone()
3752                } else {
3753                    outers[*level - 1].column_types[*column].clone()
3754                }
3755            }
3756            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3757            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3758            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3759            HirScalarExpr::CallUnary {
3760                expr,
3761                func,
3762                name: _,
3763            } => func.output_type(expr.typ(outers, inner, params)),
3764            HirScalarExpr::CallBinary {
3765                expr1,
3766                expr2,
3767                func,
3768                name: _,
3769            } => func.output_type(
3770                expr1.typ(outers, inner, params),
3771                expr2.typ(outers, inner, params),
3772            ),
3773            HirScalarExpr::CallVariadic {
3774                exprs,
3775                func,
3776                name: _,
3777            } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3778            HirScalarExpr::If {
3779                cond: _,
3780                then,
3781                els,
3782                name: _,
3783            } => {
3784                let then_type = then.typ(outers, inner, params);
3785                let else_type = els.typ(outers, inner, params);
3786                then_type.union(&else_type).unwrap()
3787            }
3788            HirScalarExpr::Exists(_, _name) => ScalarType::Bool.nullable(true),
3789            HirScalarExpr::Select(expr, _name) => {
3790                let mut outers = outers.to_vec();
3791                outers.insert(0, inner.clone());
3792                expr.typ(&outers, params)
3793                    .column_types
3794                    .into_element()
3795                    .nullable(true)
3796            }
3797            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3798        })
3799    }
3800}
3801
3802impl AggregateExpr {
3803    pub fn typ(
3804        &self,
3805        outers: &[RelationType],
3806        inner: &RelationType,
3807        params: &BTreeMap<usize, ScalarType>,
3808    ) -> ColumnType {
3809        self.func.output_type(self.expr.typ(outers, inner, params))
3810    }
3811
3812    /// Returns whether the expression is COUNT(*) or not.  Note that
3813    /// when we define the count builtin in sql::func, we convert
3814    /// COUNT(*) to COUNT(true), making it indistinguishable from
3815    /// literal COUNT(true), but we prefer to consider this as the
3816    /// former.
3817    ///
3818    /// (MIR has the same `is_count_asterisk`.)
3819    pub fn is_count_asterisk(&self) -> bool {
3820        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3821    }
3822}