Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25use mz_expr::func::variadic::{And, Or};
26pub use mz_expr::{
27    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
28};
29use mz_ore::collections::CollectionExt;
30use mz_ore::error::ErrorExt;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::separated;
33use mz_ore::treat_as_equal::TreatAsEqual;
34use mz_ore::{soft_assert_or_log, stack};
35use mz_repr::adt::array::ArrayDimension;
36use mz_repr::adt::numeric::NumericMaxScale;
37use mz_repr::*;
38use serde::{Deserialize, Serialize};
39
40use crate::plan::error::PlanError;
41use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
42use crate::plan::typeconv::{self, CastContext, plan_cast};
43use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
44
45use super::plan_utils::GroupSizeHints;
46
/// Marker type that selects HIR as the representation for the [`IR`] and
/// [`AlgExcept`] virtual-syntax traits implemented below.
///
/// It is a unit struct: it carries no data and is only used at the type level.
#[allow(missing_debug_implementations)]
pub struct Hir;
49
// Binds the generic `IR` abstraction to the concrete HIR expression types.
impl IR for Hir {
    // Relational expressions of this IR are `HirRelationExpr`s.
    type Relation = HirRelationExpr;
    // Scalar expressions of this IR are `HirScalarExpr`s.
    type Scalar = HirScalarExpr;
}
54
55impl AlgExcept for Hir {
56    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
57        if *all {
58            let rhs = rhs.negate();
59            HirRelationExpr::union(lhs, rhs).threshold()
60        } else {
61            let lhs = lhs.distinct();
62            let rhs = rhs.distinct().negate();
63            HirRelationExpr::union(lhs, rhs).threshold()
64        }
65    }
66
67    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
68        let mut result = None;
69
70        use HirRelationExpr::*;
71        if let Threshold { input } = expr {
72            if let Union { base: lhs, inputs } = input.as_ref() {
73                if let [rhs] = &inputs[..] {
74                    if let Negate { input: rhs } = rhs {
75                        match (lhs.as_ref(), rhs.as_ref()) {
76                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
77                                let all = false;
78                                let lhs = lhs.as_ref();
79                                let rhs = rhs.as_ref();
80                                result = Some(Except { all, lhs, rhs })
81                            }
82                            (lhs, rhs) => {
83                                let all = true;
84                                result = Some(Except { all, lhs, rhs })
85                            }
86                        }
87                    }
88                }
89            }
90        }
91
92        result
93    }
94}
95
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
pub enum HirRelationExpr {
    /// A constant collection given by `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, with relation type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    // NOTE(review): `outputs` presumably lists the column indices of `input` to
    // keep, as in `MirRelationExpr::Project` -- confirm.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    // NOTE(review): `scalars` presumably are appended as new columns after those
    // of `input`, as in `MirRelationExpr::Map` -- confirm.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call to the table function `func` with arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every collection in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
209
/// Stored column metadata.
///
/// An optional, cheaply clonable (`Arc<str>`) name attached to scalar
/// expressions. It is wrapped in [`TreatAsEqual`], so -- presumably, going by
/// the wrapper's name -- it does not take part in the derived comparisons and
/// hashing of the expressions that carry it (confirm against `mz_ore`).
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
212
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant additionally carries a [`NameMetadata`].
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A statement parameter, identified by its number.
    Parameter(usize, NameMetadata),
    /// A literal value, stored as a `Row` together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an unmaterializable function.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional with condition `cond` and branches `then` and `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// The invocation of a window function; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
264
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, together with its parameters.
    pub func: WindowExprType,
    /// The expressions of the partitioning; empty when there is none.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
293
294impl WindowExpr {
295    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
296    where
297        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
298    {
299        #[allow(deprecated)]
300        self.func.visit_expressions(f)?;
301        for expr in self.partition_by.iter() {
302            f(expr)?;
303        }
304        for expr in self.order_by.iter() {
305            f(expr)?;
306        }
307        Ok(())
308    }
309
310    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
311    where
312        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
313    {
314        #[allow(deprecated)]
315        self.func.visit_expressions_mut(f)?;
316        for expr in self.partition_by.iter_mut() {
317            f(expr)?;
318        }
319        for expr in self.order_by.iter_mut() {
320            f(expr)?;
321        }
322        Ok(())
323    }
324}
325
326impl VisitChildren<HirScalarExpr> for WindowExpr {
327    fn visit_children<F>(&self, mut f: F)
328    where
329        F: FnMut(&HirScalarExpr),
330    {
331        self.func.visit_children(&mut f);
332        for expr in self.partition_by.iter() {
333            f(expr);
334        }
335        for expr in self.order_by.iter() {
336            f(expr);
337        }
338    }
339
340    fn visit_mut_children<F>(&mut self, mut f: F)
341    where
342        F: FnMut(&mut HirScalarExpr),
343    {
344        self.func.visit_mut_children(&mut f);
345        for expr in self.partition_by.iter_mut() {
346            f(expr);
347        }
348        for expr in self.order_by.iter_mut() {
349            f(expr);
350        }
351    }
352
353    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
354    where
355        F: FnMut(&HirScalarExpr) -> Result<(), E>,
356        E: From<RecursionLimitError>,
357    {
358        self.func.try_visit_children(&mut f)?;
359        for expr in self.partition_by.iter() {
360            f(expr)?;
361        }
362        for expr in self.order_by.iter() {
363            f(expr)?;
364        }
365        Ok(())
366    }
367
368    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
369    where
370        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
371        E: From<RecursionLimitError>,
372    {
373        self.func.try_visit_mut_children(&mut f)?;
374        for expr in self.partition_by.iter_mut() {
375            f(expr)?;
376        }
377        for expr in self.order_by.iter_mut() {
378            f(expr)?;
379        }
380        Ok(())
381    }
382}
383
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function; see [`ScalarWindowExpr`].
    Scalar(ScalarWindowExpr),
    /// A value window function; see [`ValueWindowExpr`].
    Value(ValueWindowExpr),
    /// An aggregate window function; see [`AggregateWindowExpr`].
    Aggregate(AggregateWindowExpr),
}
416
417impl WindowExprType {
418    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
419    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
420    where
421        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
422    {
423        #[allow(deprecated)]
424        match self {
425            Self::Scalar(expr) => expr.visit_expressions(f),
426            Self::Value(expr) => expr.visit_expressions(f),
427            Self::Aggregate(expr) => expr.visit_expressions(f),
428        }
429    }
430
431    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
432    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
433    where
434        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
435    {
436        #[allow(deprecated)]
437        match self {
438            Self::Scalar(expr) => expr.visit_expressions_mut(f),
439            Self::Value(expr) => expr.visit_expressions_mut(f),
440            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
441        }
442    }
443
444    fn typ(
445        &self,
446        outers: &[SqlRelationType],
447        inner: &SqlRelationType,
448        params: &BTreeMap<usize, SqlScalarType>,
449    ) -> SqlColumnType {
450        match self {
451            Self::Scalar(expr) => expr.typ(outers, inner, params),
452            Self::Value(expr) => expr.typ(outers, inner, params),
453            Self::Aggregate(expr) => expr.typ(outers, inner, params),
454        }
455    }
456}
457
458impl VisitChildren<HirScalarExpr> for WindowExprType {
459    fn visit_children<F>(&self, f: F)
460    where
461        F: FnMut(&HirScalarExpr),
462    {
463        match self {
464            Self::Scalar(_) => (),
465            Self::Value(expr) => expr.visit_children(f),
466            Self::Aggregate(expr) => expr.visit_children(f),
467        }
468    }
469
470    fn visit_mut_children<F>(&mut self, f: F)
471    where
472        F: FnMut(&mut HirScalarExpr),
473    {
474        match self {
475            Self::Scalar(_) => (),
476            Self::Value(expr) => expr.visit_mut_children(f),
477            Self::Aggregate(expr) => expr.visit_mut_children(f),
478        }
479    }
480
481    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
482    where
483        F: FnMut(&HirScalarExpr) -> Result<(), E>,
484        E: From<RecursionLimitError>,
485    {
486        match self {
487            Self::Scalar(_) => Ok(()),
488            Self::Value(expr) => expr.try_visit_children(f),
489            Self::Aggregate(expr) => expr.try_visit_children(f),
490        }
491    }
492
493    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
494    where
495        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
496        E: From<RecursionLimitError>,
497    {
498        match self {
499            Self::Scalar(_) => Ok(()),
500            Self::Value(expr) => expr.try_visit_mut_children(f),
501            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
502        }
503    }
504}
505
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A scalar window function together with the ordering it is evaluated under.
pub struct ScalarWindowExpr {
    /// Which scalar window function is being invoked.
    pub func: ScalarWindowFunc,
    /// The column ordering handed to the resulting `mz_expr::AggregateFunc` by
    /// `into_expr`.
    pub order_by: Vec<ColumnOrder>,
}
521
522impl ScalarWindowExpr {
523    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
524    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
525    where
526        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
527    {
528        match self.func {
529            ScalarWindowFunc::RowNumber => {}
530            ScalarWindowFunc::Rank => {}
531            ScalarWindowFunc::DenseRank => {}
532        }
533        Ok(())
534    }
535
536    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
537    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
538    where
539        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
540    {
541        match self.func {
542            ScalarWindowFunc::RowNumber => {}
543            ScalarWindowFunc::Rank => {}
544            ScalarWindowFunc::DenseRank => {}
545        }
546        Ok(())
547    }
548
549    fn typ(
550        &self,
551        _outers: &[SqlRelationType],
552        _inner: &SqlRelationType,
553        _params: &BTreeMap<usize, SqlScalarType>,
554    ) -> SqlColumnType {
555        self.func.output_sql_type()
556    }
557
558    pub fn into_expr(self) -> mz_expr::AggregateFunc {
559        match self.func {
560            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
561                order_by: self.order_by,
562            },
563            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
564                order_by: self.order_by,
565            },
566            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
567                order_by: self.order_by,
568            },
569        }
570    }
571}
572
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
590
591impl Display for ScalarWindowFunc {
592    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
593        match self {
594            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
595            ScalarWindowFunc::Rank => write!(f, "rank"),
596            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
597        }
598    }
599}
600
601impl ScalarWindowFunc {
602    pub fn output_sql_type(&self) -> SqlColumnType {
603        match self {
604            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
605            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
606            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
607        }
608    }
609}
610
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A value window function together with its arguments, ordering and framing.
pub struct ValueWindowExpr {
    /// Which value window function is being invoked.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame, handed to `ValueWindowFunc::into_expr` on conversion.
    pub window_frame: WindowFrame,
    /// The IGNORE NULLS flag, handed to `ValueWindowFunc::into_expr` on conversion.
    pub ignore_nulls: bool,
}
635
636impl Display for ValueWindowFunc {
637    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
638        match self {
639            ValueWindowFunc::Lag => write!(f, "lag"),
640            ValueWindowFunc::Lead => write!(f, "lead"),
641            ValueWindowFunc::FirstValue => write!(f, "first_value"),
642            ValueWindowFunc::LastValue => write!(f, "last_value"),
643            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
644        }
645    }
646}
647
648impl ValueWindowExpr {
649    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
650    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
651    where
652        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
653    {
654        f(&self.args)
655    }
656
657    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
658    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
659    where
660        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
661    {
662        f(&mut self.args)
663    }
664
665    fn typ(
666        &self,
667        outers: &[SqlRelationType],
668        inner: &SqlRelationType,
669        params: &BTreeMap<usize, SqlScalarType>,
670    ) -> SqlColumnType {
671        self.func
672            .output_sql_type(self.args.typ(outers, inner, params))
673    }
674
675    /// Converts into `mz_expr::AggregateFunc`.
676    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
677        (
678            self.args,
679            self.func
680                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
681        )
682    }
683}
684
685impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
686    fn visit_children<F>(&self, mut f: F)
687    where
688        F: FnMut(&HirScalarExpr),
689    {
690        f(&self.args)
691    }
692
693    fn visit_mut_children<F>(&mut self, mut f: F)
694    where
695        F: FnMut(&mut HirScalarExpr),
696    {
697        f(&mut self.args)
698    }
699
700    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
701    where
702        F: FnMut(&HirScalarExpr) -> Result<(), E>,
703        E: From<RecursionLimitError>,
704    {
705        f(&self.args)
706    }
707
708    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
709    where
710        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
711        E: From<RecursionLimitError>,
712    {
713        f(&mut self.args)
714    }
715}
716
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// Displayed as `lag`.
    Lag,
    /// Displayed as `lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions fused into one; displayed as
    /// `fused[...]` listing the constituent functions.
    Fused(Vec<ValueWindowFunc>),
}
736
737impl ValueWindowFunc {
738    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
739        match self {
740            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
741                // The input is a (value, offset, default) record, so extract the type of the first arg
742                input_type.scalar_type.unwrap_record_element_type()[0]
743                    .clone()
744                    .nullable(true)
745            }
746            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
747                input_type.scalar_type.nullable(true)
748            }
749            ValueWindowFunc::Fused(funcs) => {
750                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
751                SqlScalarType::Record {
752                    fields: funcs
753                        .iter()
754                        .zip_eq(input_types)
755                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
756                        .collect(),
757                    custom_id: None,
758                }
759                .nullable(false)
760            }
761        }
762    }
763
764    pub fn into_expr(
765        self,
766        order_by: Vec<ColumnOrder>,
767        window_frame: WindowFrame,
768        ignore_nulls: bool,
769    ) -> mz_expr::AggregateFunc {
770        match self {
771            // Lag and Lead are fundamentally the same function, just with opposite directions
772            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
773                order_by,
774                lag_lead: mz_expr::LagLeadType::Lag,
775                ignore_nulls,
776            },
777            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
778                order_by,
779                lag_lead: mz_expr::LagLeadType::Lead,
780                ignore_nulls,
781            },
782            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
783                order_by,
784                window_frame,
785            },
786            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
787                order_by,
788                window_frame,
789            },
790            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
791                funcs: funcs
792                    .into_iter()
793                    .map(|func| {
794                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
795                    })
796                    .collect(),
797                order_by,
798            },
799        }
800    }
801}
802
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// An aggregation used as a window function, together with its ordering and
/// framing.
pub struct AggregateWindowExpr {
    /// The aggregate function and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// The column ordering handed to the resulting `mz_expr::AggregateFunc` by
    /// `into_expr`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame handed to the resulting `mz_expr::AggregateFunc` by
    /// `into_expr`.
    pub window_frame: WindowFrame,
}
819
820impl AggregateWindowExpr {
821    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
822    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
823    where
824        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
825    {
826        f(&self.aggregate_expr.expr)
827    }
828
829    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
830    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
831    where
832        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
833    {
834        f(&mut self.aggregate_expr.expr)
835    }
836
837    fn typ(
838        &self,
839        outers: &[SqlRelationType],
840        inner: &SqlRelationType,
841        params: &BTreeMap<usize, SqlScalarType>,
842    ) -> SqlColumnType {
843        self.aggregate_expr
844            .func
845            .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params))
846    }
847
848    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
849        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
850            (
851                self.aggregate_expr.expr,
852                FusedWindowAggregate {
853                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
854                    order_by: self.order_by,
855                    window_frame: self.window_frame,
856                },
857            )
858        } else {
859            (
860                self.aggregate_expr.expr,
861                WindowAggregate {
862                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
863                    order_by: self.order_by,
864                    window_frame: self.window_frame,
865                },
866            )
867        }
868    }
869}
870
871impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
872    fn visit_children<F>(&self, mut f: F)
873    where
874        F: FnMut(&HirScalarExpr),
875    {
876        f(&self.aggregate_expr.expr)
877    }
878
879    fn visit_mut_children<F>(&mut self, mut f: F)
880    where
881        F: FnMut(&mut HirScalarExpr),
882    {
883        f(&mut self.aggregate_expr.expr)
884    }
885
886    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
887    where
888        F: FnMut(&HirScalarExpr) -> Result<(), E>,
889        E: From<RecursionLimitError>,
890    {
891        f(&self.aggregate_expr.expr)
892    }
893
894    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
895    where
896        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
897        E: From<RecursionLimitError>,
898    {
899        f(&mut self.aggregate_expr.expr)
900    }
901}
902
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that is already a plain [`HirScalarExpr`].
    Coerced(HirScalarExpr),
    /// A parameter (such as `$1` above), identified by its number, that has
    /// not been coerced yet.
    Parameter(usize),
    /// A `NULL` literal that has not been coerced yet.
    LiteralNull,
    /// A string literal (such as `'42'` above) that has not been coerced yet.
    LiteralString(String),
    /// A record literal whose elements may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
935
936impl CoercibleScalarExpr {
937    pub fn type_as(
938        self,
939        ecx: &ExprContext,
940        ty: &SqlScalarType,
941    ) -> Result<HirScalarExpr, PlanError> {
942        let expr = typeconv::plan_coerce(ecx, self, ty)?;
943        let expr_ty = ecx.scalar_type(&expr);
944        if ty != &expr_ty {
945            sql_bail!(
946                "{} must have type {}, not type {}",
947                ecx.name,
948                ecx.humanize_scalar_type(ty, false),
949                ecx.humanize_scalar_type(&expr_ty, false),
950            );
951        }
952        Ok(expr)
953    }
954
955    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
956        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
957    }
958
959    pub fn cast_to(
960        self,
961        ecx: &ExprContext,
962        ccx: CastContext,
963        ty: &SqlScalarType,
964    ) -> Result<HirScalarExpr, PlanError> {
965        let expr = typeconv::plan_coerce(ecx, self, ty)?;
966        typeconv::plan_cast(ecx, ccx, expr, ty)
967    }
968}
969
/// The column type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// A fully determined column type.
    Coerced(SqlColumnType),
    /// A record literal's type: one entry per field, each of which may itself
    /// still be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// A type that has not been determined yet.
    Uncoerced,
}
977
978impl CoercibleColumnType {
979    /// Reports the nullability of the type.
980    pub fn nullable(&self) -> bool {
981        match self {
982            // A coerced value's nullability is known.
983            CoercibleColumnType::Coerced(ct) => ct.nullable,
984
985            // A literal record can never be null.
986            CoercibleColumnType::Record(_) => false,
987
988            // An uncoerced literal may be the literal `NULL`, so we have
989            // to conservatively assume it is nullable.
990            CoercibleColumnType::Uncoerced => true,
991        }
992    }
993}
994
/// The scalar type for a [`CoercibleScalarExpr`].
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A fully determined scalar type.
    Coerced(SqlScalarType),
    /// A record literal's type: one entry per field, each of which may itself
    /// still be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// A type that has not been determined yet.
    Uncoerced,
}
1002
1003impl CoercibleScalarType {
1004    /// Reports whether the scalar type has been coerced.
1005    pub fn is_coerced(&self) -> bool {
1006        matches!(self, CoercibleScalarType::Coerced(_))
1007    }
1008
1009    /// Returns the coerced scalar type, if the type is coerced.
1010    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1011        match self {
1012            CoercibleScalarType::Coerced(t) => Some(t),
1013            _ => None,
1014        }
1015    }
1016
1017    /// If the type is coerced, apply the mapping function to the contained
1018    /// scalar type.
1019    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1020    where
1021        F: FnOnce(SqlScalarType) -> SqlScalarType,
1022    {
1023        match self {
1024            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1025            _ => self,
1026        }
1027    }
1028
1029    /// If the type is an coercible record, forcibly converts to a coerced
1030    /// record type. Any uncoerced field types are assumed to be of type text.
1031    ///
1032    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
1033    /// accepts a type hint that can indicate the types of uncoerced field
1034    /// types.
1035    pub fn force_coerced_if_record(&mut self) {
1036        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1037            let mut fields = vec![];
1038            for (i, uf) in uncoerced_fields.enumerate() {
1039                let name = ColumnName::from(format!("f{}", i + 1));
1040                let ty = match uf {
1041                    CoercibleColumnType::Coerced(ty) => ty,
1042                    CoercibleColumnType::Record(mut fields) => {
1043                        convert(fields.drain(..)).nullable(false)
1044                    }
1045                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1046                };
1047                fields.push((name, ty))
1048            }
1049            SqlScalarType::Record {
1050                fields: fields.into(),
1051                custom_id: None,
1052            }
1053        }
1054
1055        if let CoercibleScalarType::Record(fields) = self {
1056            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1057        }
1058    }
1059}
1060
/// An expression whose type can be ascertained.
///
/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
pub trait AbstractExpr {
    /// The column type-like object produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of the expression.
    ///
    /// `outers` gives types for outer relations.
    /// `inner` gives the type of the relation the expression is evaluated
    /// against.
    /// `params` gives types for parameters.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
1075
1076impl AbstractExpr for CoercibleScalarExpr {
1077    type Type = CoercibleColumnType;
1078
1079    fn typ(
1080        &self,
1081        outers: &[SqlRelationType],
1082        inner: &SqlRelationType,
1083        params: &BTreeMap<usize, SqlScalarType>,
1084    ) -> Self::Type {
1085        match self {
1086            CoercibleScalarExpr::Coerced(expr) => {
1087                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1088            }
1089            CoercibleScalarExpr::LiteralRecord(scalars) => {
1090                let fields = scalars
1091                    .iter()
1092                    .map(|s| s.typ(outers, inner, params))
1093                    .collect();
1094                CoercibleColumnType::Record(fields)
1095            }
1096            _ => CoercibleColumnType::Uncoerced,
1097        }
1098    }
1099}
1100
/// A column type-like object whose underlying scalar type-like object can be
/// ascertained.
///
/// Abstracts over `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object returned by
    /// [`AbstractColumnType::scalar_type`].
    type AbstractScalarType;

    /// Converts the column type-like object into its inner scalar type-like
    /// object, consuming `self`.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1112
1113impl AbstractColumnType for SqlColumnType {
1114    type AbstractScalarType = SqlScalarType;
1115
1116    fn scalar_type(self) -> Self::AbstractScalarType {
1117        self.scalar_type
1118    }
1119}
1120
1121impl AbstractColumnType for CoercibleColumnType {
1122    type AbstractScalarType = CoercibleScalarType;
1123
1124    fn scalar_type(self) -> Self::AbstractScalarType {
1125        match self {
1126            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1127            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1128            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1129        }
1130    }
1131}
1132
1133impl From<HirScalarExpr> for CoercibleScalarExpr {
1134    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1135        CoercibleScalarExpr::Coerced(expr)
1136    }
1137}
1138
/// A leveled column reference.
///
/// In the course of decorrelation, multiple levels of nested subqueries are
/// traversed, and references to columns may correspond to different levels
/// of containing outer subqueries.
///
/// A `ColumnRef` allows expressions to refer to columns while being clear
/// about which level the column references without manually performing the
/// bookkeeping tracking their actual column locations.
///
/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
/// from the reference, using `column` as a unique identifier in that subquery level.
/// A `level` of zero corresponds to the current scope, and levels increase to
/// indicate subqueries further "outwards".
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize
)]
pub struct ColumnRef {
    /// Scope level, where 0 is the current scope and 1+ are outer scopes.
    pub level: usize,
    /// Level-local column identifier.
    pub column: usize,
}
1171
/// The kind of a join between a left and a right input.
///
/// When computing the type of a join, the outer kinds make the columns of
/// one or both inputs nullable, as noted on each variant.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum JoinKind {
    /// Neither input's columns are made nullable.
    Inner,
    /// The right input's columns are made nullable.
    LeftOuter,
    /// The left input's columns are made nullable.
    RightOuter,
    /// Both inputs' columns are made nullable.
    FullOuter,
}
1189
1190impl fmt::Display for JoinKind {
1191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1192        write!(
1193            f,
1194            "{}",
1195            match self {
1196                JoinKind::Inner => "Inner",
1197                JoinKind::LeftOuter => "LeftOuter",
1198                JoinKind::RightOuter => "RightOuter",
1199                JoinKind::FullOuter => "FullOuter",
1200            }
1201        )
1202    }
1203}
1204
1205impl JoinKind {
1206    pub fn can_be_correlated(&self) -> bool {
1207        match self {
1208            JoinKind::Inner | JoinKind::LeftOuter => true,
1209            JoinKind::RightOuter | JoinKind::FullOuter => false,
1210        }
1211    }
1212
1213    pub fn can_elide_identity_left_join(&self) -> bool {
1214        match self {
1215            JoinKind::Inner | JoinKind::RightOuter => true,
1216            JoinKind::LeftOuter | JoinKind::FullOuter => false,
1217        }
1218    }
1219
1220    pub fn can_elide_identity_right_join(&self) -> bool {
1221        match self {
1222            JoinKind::Inner | JoinKind::LeftOuter => true,
1223            JoinKind::RightOuter | JoinKind::FullOuter => false,
1224        }
1225    }
1226}
1227
/// An aggregate function together with the scalar expression it is applied to.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub func: AggregateFunc,
    /// The expression that provides the input to the aggregate function.
    pub expr: Box<HirScalarExpr>,
    // NOTE(review): presumably set for `agg(DISTINCT ...)`; confirm against
    // the planner.
    pub distinct: bool,
}
1244
/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
/// types may be different.
///
/// Specifically, the nullability of the aggregate columns is more common
/// here than in `expr`, as these aggregates may be applied over empty
/// result sets and should be null in those cases, whereas `expr` variants
/// only return null values when supplied nulls as input.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    Count,
    Any,
    All,
    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`
    /// into a JSON list. The other elements are columns used by `order_by`.
    ///
    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
    /// layer, this function filters out `Datum::Null`, for consistency with
    /// the other aggregate functions.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum` into a
    /// JSON map. The other elements are columns used by `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
    /// elements are columns used by `order_by`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
    /// single `Datum::Array`. The other elements are columns used by `order_by`.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
    /// single `Datum::List`. The other elements are columns used by `order_by`.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    // NOTE(review): presumably the input is wrapped like the other ordered
    // aggregates above (first element is the value, the rest are `order_by`
    // columns); confirm before documenting.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// A bundle of fused window aggregations: its input is a record, each
    /// component of which will be the input to one of the `AggregateFunc`s.
    ///
    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
    /// more specifically an `AggregateWindowExpr`.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
    ///
    /// Useful for removing an expensive aggregation while maintaining the shape
    /// of a reduce operator.
    Dummy,
}
1358
1359impl AggregateFunc {
1360    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1361    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1362        match self {
1363            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1364            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1365            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1366            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1367            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1368            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1369            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1370            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1371            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1372            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1373            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1374            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1375            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1376            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1377            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1378            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1379            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1380            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1381            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1382            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1383            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1384            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1385            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1386            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1387            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1388            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1389            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1390            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1391            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1392            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1393            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1394            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1395            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1396            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1397            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1398            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1399            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1400            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1401            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1402            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1403            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1404            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1405            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1406            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1407            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1408            AggregateFunc::All => mz_expr::AggregateFunc::All,
1409            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1410            AggregateFunc::JsonbObjectAgg { order_by } => {
1411                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1412            }
1413            AggregateFunc::MapAgg {
1414                order_by,
1415                value_type,
1416            } => mz_expr::AggregateFunc::MapAgg {
1417                order_by,
1418                value_type,
1419            },
1420            AggregateFunc::ArrayConcat { order_by } => {
1421                mz_expr::AggregateFunc::ArrayConcat { order_by }
1422            }
1423            AggregateFunc::ListConcat { order_by } => {
1424                mz_expr::AggregateFunc::ListConcat { order_by }
1425            }
1426            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1427            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1428            // `AggregateWindowExpr::into_expr`.
1429            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1430                panic!("into_expr called on FusedWindowAgg")
1431            }
1432            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1433        }
1434    }
1435
1436    /// Returns a datum whose inclusion in the aggregation will not change its
1437    /// result.
1438    ///
1439    /// # Panics
1440    ///
1441    /// Panics if called on a `FusedWindowAgg`.
1442    pub fn identity_datum(&self) -> Datum<'static> {
1443        match self {
1444            AggregateFunc::Any => Datum::False,
1445            AggregateFunc::All => Datum::True,
1446            AggregateFunc::Dummy => Datum::Dummy,
1447            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1448            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1449            AggregateFunc::MaxNumeric
1450            | AggregateFunc::MaxInt16
1451            | AggregateFunc::MaxInt32
1452            | AggregateFunc::MaxInt64
1453            | AggregateFunc::MaxUInt16
1454            | AggregateFunc::MaxUInt32
1455            | AggregateFunc::MaxUInt64
1456            | AggregateFunc::MaxMzTimestamp
1457            | AggregateFunc::MaxFloat32
1458            | AggregateFunc::MaxFloat64
1459            | AggregateFunc::MaxBool
1460            | AggregateFunc::MaxString
1461            | AggregateFunc::MaxDate
1462            | AggregateFunc::MaxTimestamp
1463            | AggregateFunc::MaxTimestampTz
1464            | AggregateFunc::MaxInterval
1465            | AggregateFunc::MaxTime
1466            | AggregateFunc::MinNumeric
1467            | AggregateFunc::MinInt16
1468            | AggregateFunc::MinInt32
1469            | AggregateFunc::MinInt64
1470            | AggregateFunc::MinUInt16
1471            | AggregateFunc::MinUInt32
1472            | AggregateFunc::MinUInt64
1473            | AggregateFunc::MinMzTimestamp
1474            | AggregateFunc::MinFloat32
1475            | AggregateFunc::MinFloat64
1476            | AggregateFunc::MinBool
1477            | AggregateFunc::MinString
1478            | AggregateFunc::MinDate
1479            | AggregateFunc::MinTimestamp
1480            | AggregateFunc::MinTimestampTz
1481            | AggregateFunc::MinInterval
1482            | AggregateFunc::MinTime
1483            | AggregateFunc::SumInt16
1484            | AggregateFunc::SumInt32
1485            | AggregateFunc::SumInt64
1486            | AggregateFunc::SumUInt16
1487            | AggregateFunc::SumUInt32
1488            | AggregateFunc::SumUInt64
1489            | AggregateFunc::SumFloat32
1490            | AggregateFunc::SumFloat64
1491            | AggregateFunc::SumNumeric
1492            | AggregateFunc::Count
1493            | AggregateFunc::JsonbAgg { .. }
1494            | AggregateFunc::JsonbObjectAgg { .. }
1495            | AggregateFunc::MapAgg { .. }
1496            | AggregateFunc::StringAgg { .. } => Datum::Null,
1497            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1498                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1499                // in HIR planning, because it is introduced only during HIR transformation.
1500                //
1501                // The implementation could be something like the following, except that we need to
1502                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1503                // ```
1504                // let temp_storage = RowArena::new();
1505                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1506                // ```
1507                panic!("FusedWindowAgg doesn't have an identity_datum")
1508            }
1509        }
1510    }
1511
1512    /// The output column type for the result of an aggregation.
1513    ///
1514    /// The output column type also contains nullability information, which
1515    /// is (without further information) true for aggregations that are not
1516    /// counts.
1517    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1518        let scalar_type = match self {
1519            AggregateFunc::Count => SqlScalarType::Int64,
1520            AggregateFunc::Any => SqlScalarType::Bool,
1521            AggregateFunc::All => SqlScalarType::Bool,
1522            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1523            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1524            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1525            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1526            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1527                max_scale: Some(NumericMaxScale::ZERO),
1528            },
1529            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1530            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1531                max_scale: Some(NumericMaxScale::ZERO),
1532            },
1533            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1534                value_type: Box::new(value_type.clone()),
1535                custom_id: None,
1536            },
1537            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1538                match input_type.scalar_type {
1539                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1540                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1541                    _ => unreachable!(),
1542                }
1543            }
1544            AggregateFunc::MaxNumeric
1545            | AggregateFunc::MaxInt16
1546            | AggregateFunc::MaxInt32
1547            | AggregateFunc::MaxInt64
1548            | AggregateFunc::MaxUInt16
1549            | AggregateFunc::MaxUInt32
1550            | AggregateFunc::MaxUInt64
1551            | AggregateFunc::MaxMzTimestamp
1552            | AggregateFunc::MaxFloat32
1553            | AggregateFunc::MaxFloat64
1554            | AggregateFunc::MaxBool
1555            | AggregateFunc::MaxString
1556            | AggregateFunc::MaxDate
1557            | AggregateFunc::MaxTimestamp
1558            | AggregateFunc::MaxTimestampTz
1559            | AggregateFunc::MaxInterval
1560            | AggregateFunc::MaxTime
1561            | AggregateFunc::MinNumeric
1562            | AggregateFunc::MinInt16
1563            | AggregateFunc::MinInt32
1564            | AggregateFunc::MinInt64
1565            | AggregateFunc::MinUInt16
1566            | AggregateFunc::MinUInt32
1567            | AggregateFunc::MinUInt64
1568            | AggregateFunc::MinMzTimestamp
1569            | AggregateFunc::MinFloat32
1570            | AggregateFunc::MinFloat64
1571            | AggregateFunc::MinBool
1572            | AggregateFunc::MinString
1573            | AggregateFunc::MinDate
1574            | AggregateFunc::MinTimestamp
1575            | AggregateFunc::MinTimestampTz
1576            | AggregateFunc::MinInterval
1577            | AggregateFunc::MinTime
1578            | AggregateFunc::SumFloat32
1579            | AggregateFunc::SumFloat64
1580            | AggregateFunc::SumNumeric
1581            | AggregateFunc::Dummy => input_type.scalar_type,
1582            AggregateFunc::FusedWindowAgg { funcs } => {
1583                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1584                SqlScalarType::Record {
1585                    fields: funcs
1586                        .iter()
1587                        .zip_eq(input_types)
1588                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
1589                        .collect(),
1590                    custom_id: None,
1591                }
1592            }
1593        };
1594        // max/min/sum return null on empty sets
1595        let nullable = !matches!(self, AggregateFunc::Count);
1596        scalar_type.nullable(nullable)
1597    }
1598
1599    pub fn is_order_sensitive(&self) -> bool {
1600        use AggregateFunc::*;
1601        matches!(
1602            self,
1603            JsonbAgg { .. }
1604                | JsonbObjectAgg { .. }
1605                | MapAgg { .. }
1606                | ArrayConcat { .. }
1607                | ListConcat { .. }
1608                | StringAgg { .. }
1609        )
1610    }
1611}
1612
1613impl HirRelationExpr {
1614    /// Gets the SQL type of a self-contained, top-level expression.
1615    pub fn top_level_typ(&self) -> SqlRelationType {
1616        self.typ(&[], &BTreeMap::new())
1617    }
1618
1619    /// Gets the SQL type of the expression.
1620    ///
1621    /// `outers` gives types for outer relations.
1622    /// `params` gives types for parameters.
1623    pub fn typ(
1624        &self,
1625        outers: &[SqlRelationType],
1626        params: &BTreeMap<usize, SqlScalarType>,
1627    ) -> SqlRelationType {
1628        stack::maybe_grow(|| match self {
1629            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1630            HirRelationExpr::Get { typ, .. } => typ.clone(),
1631            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1632            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1633            HirRelationExpr::Project { input, outputs } => {
1634                let input_typ = input.typ(outers, params);
1635                SqlRelationType::new(
1636                    outputs
1637                        .iter()
1638                        .map(|&i| input_typ.column_types[i].clone())
1639                        .collect(),
1640                )
1641            }
1642            HirRelationExpr::Map { input, scalars } => {
1643                let mut typ = input.typ(outers, params);
1644                for scalar in scalars {
1645                    typ.column_types.push(scalar.typ(outers, &typ, params));
1646                }
1647                typ
1648            }
1649            HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(),
1650            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1651                input.typ(outers, params)
1652            }
1653            HirRelationExpr::Join {
1654                left, right, kind, ..
1655            } => {
1656                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1657                let right_nullable =
1658                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1659                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1660                    let nullable = t.nullable || left_nullable;
1661                    t.nullable(nullable)
1662                });
1663                let mut outers = outers.to_vec();
1664                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1665                let rt = right
1666                    .typ(&outers, params)
1667                    .column_types
1668                    .into_iter()
1669                    .map(|t| {
1670                        let nullable = t.nullable || right_nullable;
1671                        t.nullable(nullable)
1672                    });
1673                SqlRelationType::new(lt.chain(rt).collect())
1674            }
1675            HirRelationExpr::Reduce {
1676                input,
1677                group_key,
1678                aggregates,
1679                expected_group_size: _,
1680            } => {
1681                let input_typ = input.typ(outers, params);
1682                let mut column_types = group_key
1683                    .iter()
1684                    .map(|&i| input_typ.column_types[i].clone())
1685                    .collect::<Vec<_>>();
1686                for agg in aggregates {
1687                    column_types.push(agg.typ(outers, &input_typ, params));
1688                }
1689                // TODO(frank): add primary key information.
1690                SqlRelationType::new(column_types)
1691            }
1692            // TODO(frank): check for removal; add primary key information.
1693            HirRelationExpr::Distinct { input }
1694            | HirRelationExpr::Negate { input }
1695            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1696            HirRelationExpr::Union { base, inputs } => {
1697                let mut base_cols = base.typ(outers, params).column_types;
1698                for input in inputs {
1699                    for (base_col, col) in base_cols
1700                        .iter_mut()
1701                        .zip_eq(input.typ(outers, params).column_types)
1702                    {
1703                        *base_col = base_col.sql_union(&col).unwrap(); // HIR deliberately not using `union`
1704                    }
1705                }
1706                SqlRelationType::new(base_cols)
1707            }
1708        })
1709    }
1710
1711    pub fn arity(&self) -> usize {
1712        match self {
1713            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1714            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1715            HirRelationExpr::Let { body, .. } => body.arity(),
1716            HirRelationExpr::LetRec { body, .. } => body.arity(),
1717            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1718            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1719            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1720            HirRelationExpr::Filter { input, .. }
1721            | HirRelationExpr::TopK { input, .. }
1722            | HirRelationExpr::Distinct { input }
1723            | HirRelationExpr::Negate { input }
1724            | HirRelationExpr::Threshold { input } => input.arity(),
1725            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1726            HirRelationExpr::Union { base, .. } => base.arity(),
1727            HirRelationExpr::Reduce {
1728                group_key,
1729                aggregates,
1730                ..
1731            } => group_key.len() + aggregates.len(),
1732        }
1733    }
1734
1735    /// If self is a constant, return the value and the type, otherwise `None`.
1736    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1737        match self {
1738            Self::Constant { rows, typ } => Some((rows, typ)),
1739            _ => None,
1740        }
1741    }
1742
1743    /// Reports whether this expression contains a column reference to its
1744    /// direct parent scope.
1745    pub fn is_correlated(&self) -> bool {
1746        let mut correlated = false;
1747        #[allow(deprecated)]
1748        self.visit_columns(0, &mut |depth, col| {
1749            if col.level > depth && col.level - depth == 1 {
1750                correlated = true;
1751            }
1752        });
1753        correlated
1754    }
1755
1756    pub fn is_join_identity(&self) -> bool {
1757        match self {
1758            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1759            _ => false,
1760        }
1761    }
1762
1763    pub fn project(self, outputs: Vec<usize>) -> Self {
1764        if outputs.iter().copied().eq(0..self.arity()) {
1765            // The projection is trivial. Suppress it.
1766            self
1767        } else {
1768            HirRelationExpr::Project {
1769                input: Box::new(self),
1770                outputs,
1771            }
1772        }
1773    }
1774
1775    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1776        if scalars.is_empty() {
1777            // The map is trivial. Suppress it.
1778            self
1779        } else if let HirRelationExpr::Map {
1780            scalars: old_scalars,
1781            input: _,
1782        } = &mut self
1783        {
1784            // Map applied to a map. Fuse the maps.
1785            old_scalars.extend(scalars);
1786            self
1787        } else {
1788            HirRelationExpr::Map {
1789                input: Box::new(self),
1790                scalars,
1791            }
1792        }
1793    }
1794
1795    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1796        if let HirRelationExpr::Filter {
1797            input: _,
1798            predicates,
1799        } = &mut self
1800        {
1801            predicates.extend(preds);
1802            predicates.sort();
1803            predicates.dedup();
1804            self
1805        } else {
1806            preds.sort();
1807            preds.dedup();
1808            HirRelationExpr::Filter {
1809                input: Box::new(self),
1810                predicates: preds,
1811            }
1812        }
1813    }
1814
1815    pub fn reduce(
1816        self,
1817        group_key: Vec<usize>,
1818        aggregates: Vec<AggregateExpr>,
1819        expected_group_size: Option<u64>,
1820    ) -> Self {
1821        HirRelationExpr::Reduce {
1822            input: Box::new(self),
1823            group_key,
1824            aggregates,
1825            expected_group_size,
1826        }
1827    }
1828
1829    pub fn top_k(
1830        self,
1831        group_key: Vec<usize>,
1832        order_key: Vec<ColumnOrder>,
1833        limit: Option<HirScalarExpr>,
1834        offset: HirScalarExpr,
1835        expected_group_size: Option<u64>,
1836    ) -> Self {
1837        HirRelationExpr::TopK {
1838            input: Box::new(self),
1839            group_key,
1840            order_key,
1841            limit,
1842            offset,
1843            expected_group_size,
1844        }
1845    }
1846
1847    pub fn negate(self) -> Self {
1848        if let HirRelationExpr::Negate { input } = self {
1849            *input
1850        } else {
1851            HirRelationExpr::Negate {
1852                input: Box::new(self),
1853            }
1854        }
1855    }
1856
1857    pub fn distinct(self) -> Self {
1858        if let HirRelationExpr::Distinct { .. } = self {
1859            self
1860        } else {
1861            HirRelationExpr::Distinct {
1862                input: Box::new(self),
1863            }
1864        }
1865    }
1866
1867    pub fn threshold(self) -> Self {
1868        if let HirRelationExpr::Threshold { .. } = self {
1869            self
1870        } else {
1871            HirRelationExpr::Threshold {
1872                input: Box::new(self),
1873            }
1874        }
1875    }
1876
1877    pub fn union(self, other: Self) -> Self {
1878        let mut terms = Vec::new();
1879        if let HirRelationExpr::Union { base, inputs } = self {
1880            terms.push(*base);
1881            terms.extend(inputs);
1882        } else {
1883            terms.push(self);
1884        }
1885        if let HirRelationExpr::Union { base, inputs } = other {
1886            terms.push(*base);
1887            terms.extend(inputs);
1888        } else {
1889            terms.push(other);
1890        }
1891        HirRelationExpr::Union {
1892            base: Box::new(terms.remove(0)),
1893            inputs: terms,
1894        }
1895    }
1896
1897    pub fn exists(self) -> HirScalarExpr {
1898        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1899    }
1900
1901    pub fn select(self) -> HirScalarExpr {
1902        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1903    }
1904
1905    pub fn join(
1906        self,
1907        mut right: HirRelationExpr,
1908        on: HirScalarExpr,
1909        kind: JoinKind,
1910    ) -> HirRelationExpr {
1911        if self.is_join_identity()
1912            && !right.is_correlated()
1913            && on == HirScalarExpr::literal_true()
1914            && kind.can_elide_identity_left_join()
1915        {
1916            // The join can be elided, but we need to adjust column references
1917            // on the right-hand side to account for the removal of the scope
1918            // introduced by the join.
1919            #[allow(deprecated)]
1920            right.visit_columns_mut(0, &mut |depth, col| {
1921                if col.level > depth {
1922                    col.level -= 1;
1923                }
1924            });
1925            right
1926        } else if right.is_join_identity()
1927            && on == HirScalarExpr::literal_true()
1928            && kind.can_elide_identity_right_join()
1929        {
1930            self
1931        } else {
1932            HirRelationExpr::Join {
1933                left: Box::new(self),
1934                right: Box::new(right),
1935                on,
1936                kind,
1937            }
1938        }
1939    }
1940
1941    pub fn take(&mut self) -> HirRelationExpr {
1942        mem::replace(
1943            self,
1944            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1945        )
1946    }
1947
1948    #[deprecated = "Use `Visit::visit_post`."]
1949    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1950    where
1951        F: FnMut(&'a Self, usize),
1952    {
1953        #[allow(deprecated)]
1954        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1955                                                 depth: usize|
1956         -> Result<(), ()> {
1957            f(e, depth);
1958            Ok(())
1959        });
1960    }
1961
1962    #[deprecated = "Use `Visit::try_visit_post`."]
1963    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1964    where
1965        F: FnMut(&'a Self, usize) -> Result<(), E>,
1966    {
1967        #[allow(deprecated)]
1968        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1969            e.visit_fallible(depth, f)
1970        })?;
1971        f(self, depth)
1972    }
1973
1974    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1975    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1976    where
1977        F: FnMut(&'a Self, usize) -> Result<(), E>,
1978    {
1979        match self {
1980            HirRelationExpr::Constant { .. }
1981            | HirRelationExpr::Get { .. }
1982            | HirRelationExpr::CallTable { .. } => (),
1983            HirRelationExpr::Let { body, value, .. } => {
1984                f(value, depth)?;
1985                f(body, depth)?;
1986            }
1987            HirRelationExpr::LetRec {
1988                limit: _,
1989                bindings,
1990                body,
1991            } => {
1992                for (_, _, value, _) in bindings.iter() {
1993                    f(value, depth)?;
1994                }
1995                f(body, depth)?;
1996            }
1997            HirRelationExpr::Project { input, .. } => {
1998                f(input, depth)?;
1999            }
2000            HirRelationExpr::Map { input, .. } => {
2001                f(input, depth)?;
2002            }
2003            HirRelationExpr::Filter { input, .. } => {
2004                f(input, depth)?;
2005            }
2006            HirRelationExpr::Join { left, right, .. } => {
2007                f(left, depth)?;
2008                f(right, depth + 1)?;
2009            }
2010            HirRelationExpr::Reduce { input, .. } => {
2011                f(input, depth)?;
2012            }
2013            HirRelationExpr::Distinct { input } => {
2014                f(input, depth)?;
2015            }
2016            HirRelationExpr::TopK { input, .. } => {
2017                f(input, depth)?;
2018            }
2019            HirRelationExpr::Negate { input } => {
2020                f(input, depth)?;
2021            }
2022            HirRelationExpr::Threshold { input } => {
2023                f(input, depth)?;
2024            }
2025            HirRelationExpr::Union { base, inputs } => {
2026                f(base, depth)?;
2027                for input in inputs {
2028                    f(input, depth)?;
2029                }
2030            }
2031        }
2032        Ok(())
2033    }
2034
2035    #[deprecated = "Use `Visit::visit_mut_post` instead."]
2036    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2037    where
2038        F: FnMut(&mut Self, usize),
2039    {
2040        #[allow(deprecated)]
2041        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2042                                                     depth: usize|
2043         -> Result<(), ()> {
2044            f(e, depth);
2045            Ok(())
2046        });
2047    }
2048
2049    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2050    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2051    where
2052        F: FnMut(&mut Self, usize) -> Result<(), E>,
2053    {
2054        #[allow(deprecated)]
2055        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2056            e.visit_mut_fallible(depth, f)
2057        })?;
2058        f(self, depth)
2059    }
2060
2061    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2062    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2063    where
2064        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2065    {
2066        match self {
2067            HirRelationExpr::Constant { .. }
2068            | HirRelationExpr::Get { .. }
2069            | HirRelationExpr::CallTable { .. } => (),
2070            HirRelationExpr::Let { body, value, .. } => {
2071                f(value, depth)?;
2072                f(body, depth)?;
2073            }
2074            HirRelationExpr::LetRec {
2075                limit: _,
2076                bindings,
2077                body,
2078            } => {
2079                for (_, _, value, _) in bindings.iter_mut() {
2080                    f(value, depth)?;
2081                }
2082                f(body, depth)?;
2083            }
2084            HirRelationExpr::Project { input, .. } => {
2085                f(input, depth)?;
2086            }
2087            HirRelationExpr::Map { input, .. } => {
2088                f(input, depth)?;
2089            }
2090            HirRelationExpr::Filter { input, .. } => {
2091                f(input, depth)?;
2092            }
2093            HirRelationExpr::Join { left, right, .. } => {
2094                f(left, depth)?;
2095                f(right, depth + 1)?;
2096            }
2097            HirRelationExpr::Reduce { input, .. } => {
2098                f(input, depth)?;
2099            }
2100            HirRelationExpr::Distinct { input } => {
2101                f(input, depth)?;
2102            }
2103            HirRelationExpr::TopK { input, .. } => {
2104                f(input, depth)?;
2105            }
2106            HirRelationExpr::Negate { input } => {
2107                f(input, depth)?;
2108            }
2109            HirRelationExpr::Threshold { input } => {
2110                f(input, depth)?;
2111            }
2112            HirRelationExpr::Union { base, inputs } => {
2113                f(base, depth)?;
2114                for input in inputs {
2115                    f(input, depth)?;
2116                }
2117            }
2118        }
2119        Ok(())
2120    }
2121
2122    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2123    /// Visits all scalar expressions within the sub-tree of the given relation.
2124    ///
2125    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2126    /// which will be incremented when entering the RHS of a join or a subquery and
2127    /// presented to the supplied function `f`.
2128    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2129    where
2130        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2131    {
2132        #[allow(deprecated)]
2133        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2134                                         depth: usize|
2135         -> Result<(), E> {
2136            match e {
2137                HirRelationExpr::Join { on, .. } => {
2138                    f(on, depth)?;
2139                }
2140                HirRelationExpr::Map { scalars, .. } => {
2141                    for scalar in scalars {
2142                        f(scalar, depth)?;
2143                    }
2144                }
2145                HirRelationExpr::CallTable { exprs, .. } => {
2146                    for expr in exprs {
2147                        f(expr, depth)?;
2148                    }
2149                }
2150                HirRelationExpr::Filter { predicates, .. } => {
2151                    for predicate in predicates {
2152                        f(predicate, depth)?;
2153                    }
2154                }
2155                HirRelationExpr::Reduce { aggregates, .. } => {
2156                    for aggregate in aggregates {
2157                        f(&aggregate.expr, depth)?;
2158                    }
2159                }
2160                HirRelationExpr::TopK { limit, offset, .. } => {
2161                    if let Some(limit) = limit {
2162                        f(limit, depth)?;
2163                    }
2164                    f(offset, depth)?;
2165                }
2166                HirRelationExpr::Union { .. }
2167                | HirRelationExpr::Let { .. }
2168                | HirRelationExpr::LetRec { .. }
2169                | HirRelationExpr::Project { .. }
2170                | HirRelationExpr::Distinct { .. }
2171                | HirRelationExpr::Negate { .. }
2172                | HirRelationExpr::Threshold { .. }
2173                | HirRelationExpr::Constant { .. }
2174                | HirRelationExpr::Get { .. } => (),
2175            }
2176            Ok(())
2177        })
2178    }
2179
2180    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2181    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2182    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2183    where
2184        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2185    {
2186        #[allow(deprecated)]
2187        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2188                                             depth: usize|
2189         -> Result<(), E> {
2190            match e {
2191                HirRelationExpr::Join { on, .. } => {
2192                    f(on, depth)?;
2193                }
2194                HirRelationExpr::Map { scalars, .. } => {
2195                    for scalar in scalars.iter_mut() {
2196                        f(scalar, depth)?;
2197                    }
2198                }
2199                HirRelationExpr::CallTable { exprs, .. } => {
2200                    for expr in exprs.iter_mut() {
2201                        f(expr, depth)?;
2202                    }
2203                }
2204                HirRelationExpr::Filter { predicates, .. } => {
2205                    for predicate in predicates.iter_mut() {
2206                        f(predicate, depth)?;
2207                    }
2208                }
2209                HirRelationExpr::Reduce { aggregates, .. } => {
2210                    for aggregate in aggregates.iter_mut() {
2211                        f(&mut aggregate.expr, depth)?;
2212                    }
2213                }
2214                HirRelationExpr::TopK { limit, offset, .. } => {
2215                    if let Some(limit) = limit {
2216                        f(limit, depth)?;
2217                    }
2218                    f(offset, depth)?;
2219                }
2220                HirRelationExpr::Union { .. }
2221                | HirRelationExpr::Let { .. }
2222                | HirRelationExpr::LetRec { .. }
2223                | HirRelationExpr::Project { .. }
2224                | HirRelationExpr::Distinct { .. }
2225                | HirRelationExpr::Negate { .. }
2226                | HirRelationExpr::Threshold { .. }
2227                | HirRelationExpr::Constant { .. }
2228                | HirRelationExpr::Get { .. } => (),
2229            }
2230            Ok(())
2231        })
2232    }
2233
2234    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2235    /// Visits the column references in this relation expression.
2236    ///
2237    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2238    /// which will be incremented when entering the RHS of a join or a subquery and
2239    /// presented to the supplied function `f`.
2240    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2241    where
2242        F: FnMut(usize, &ColumnRef),
2243    {
2244        #[allow(deprecated)]
2245        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2246                                                           depth: usize|
2247         -> Result<(), ()> {
2248            e.visit_columns(depth, f);
2249            Ok(())
2250        });
2251    }
2252
2253    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2254    /// Like `visit_columns`, but permits mutating the column references.
2255    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2256    where
2257        F: FnMut(usize, &mut ColumnRef),
2258    {
2259        #[allow(deprecated)]
2260        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2261                                                               depth: usize|
2262         -> Result<(), ()> {
2263            e.visit_columns_mut(depth, f);
2264            Ok(())
2265        });
2266    }
2267
2268    /// Replaces any parameter references in the expression with the
2269    /// corresponding datum from `params`.
2270    pub fn bind_parameters(
2271        &mut self,
2272        scx: &StatementContext,
2273        lifetime: QueryLifetime,
2274        params: &Params,
2275    ) -> Result<(), PlanError> {
2276        #[allow(deprecated)]
2277        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2278            e.bind_parameters(scx, lifetime, params)
2279        })
2280    }
2281
2282    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2283        let mut contains_parameters = false;
2284        #[allow(deprecated)]
2285        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2286            if e.contains_parameters() {
2287                contains_parameters = true;
2288            }
2289            Ok::<(), PlanError>(())
2290        })?;
2291        Ok(contains_parameters)
2292    }
2293
2294    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2295    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2296        #[allow(deprecated)]
2297        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2298                                                               depth: usize|
2299         -> Result<(), ()> {
2300            e.splice_parameters(params, depth);
2301            Ok(())
2302        });
2303    }
2304
2305    /// Constructs a constant collection from specific rows and schema.
2306    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2307        let rows = rows
2308            .into_iter()
2309            .map(move |datums| Row::pack_slice(&datums))
2310            .collect();
2311        HirRelationExpr::Constant { rows, typ }
2312    }
2313
2314    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2315    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2316    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2317    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2318    /// finishing into a trivial finishing.
2319    pub fn finish_maintained(
2320        &mut self,
2321        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2322        group_size_hints: GroupSizeHints,
2323    ) {
2324        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2325            let old_finishing = mem::replace(
2326                finishing,
2327                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2328            );
2329            *self = HirRelationExpr::top_k(
2330                std::mem::replace(
2331                    self,
2332                    HirRelationExpr::Constant {
2333                        rows: vec![],
2334                        typ: SqlRelationType::new(Vec::new()),
2335                    },
2336                ),
2337                vec![],
2338                old_finishing.order_by,
2339                old_finishing.limit,
2340                old_finishing.offset,
2341                group_size_hints.limit_input_group_size,
2342            )
2343            .project(old_finishing.project);
2344        }
2345    }
2346
2347    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2348    ///
2349    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2350    /// parameter is not an HirScalarExpr anymore.)
2351    pub fn trivial_row_set_finishing_hir(
2352        arity: usize,
2353    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2354        RowSetFinishing {
2355            order_by: Vec::new(),
2356            limit: None,
2357            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2358            project: (0..arity).collect(),
2359        }
2360    }
2361
2362    /// True if the finishing does nothing to any result set.
2363    ///
2364    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2365    /// parameter is not an HirScalarExpr anymore.)
2366    pub fn is_trivial_row_set_finishing_hir(
2367        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2368        arity: usize,
2369    ) -> bool {
2370        rsf.limit.is_none()
2371            && rsf.order_by.is_empty()
2372            && rsf
2373                .offset
2374                .clone()
2375                .try_into_literal_int64()
2376                .is_ok_and(|o| o == 0)
2377            && rsf.project.iter().copied().eq(0..arity)
2378    }
2379
2380    /// The HirRelationExpr is considered potentially expensive if and only if
2381    /// at least one of the following conditions is true:
2382    ///
2383    ///  - It contains at least one CallTable or a Reduce operator.
2384    ///  - It contains at least one HirScalarExpr with a function call.
2385    ///
2386    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2387    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2388    pub fn could_run_expensive_function(&self) -> bool {
2389        let mut result = false;
2390        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2391            use HirRelationExpr::*;
2392            use HirScalarExpr::*;
2393
2394            self.visit_children(|scalar: &HirScalarExpr| {
2395                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2396                    result |= match scalar {
2397                        Column(..)
2398                        | Literal(..)
2399                        | CallUnmaterializable(..)
2400                        | If { .. }
2401                        | Parameter(..)
2402                        | Select(..)
2403                        | Exists(..) => false,
2404                        // Function calls are considered expensive
2405                        CallUnary { .. }
2406                        | CallBinary { .. }
2407                        | CallVariadic { .. }
2408                        | Windowing(..) => true,
2409                    };
2410                }) {
2411                    // Conservatively set `true` on RecursionLimitError.
2412                    result = true;
2413                }
2414            });
2415
2416            // CallTable has a table function; Reduce has an aggregate function.
2417            // Other constructs use MirScalarExpr to run a function
2418            result |= matches!(e, CallTable { .. } | Reduce { .. });
2419        }) {
2420            // Conservatively set `true` on RecursionLimitError.
2421            result = true;
2422        }
2423
2424        result
2425    }
2426
2427    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2428    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2429        let mut contains = false;
2430        self.visit_post(&mut |expr| {
2431            expr.visit_children(|expr: &HirScalarExpr| {
2432                contains = contains || expr.contains_temporal()
2433            })
2434        })?;
2435        Ok(contains)
2436    }
2437
2438    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2439    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2440        let mut contains = false;
2441        self.visit_post(&mut |expr| {
2442            expr.visit_children(|expr: &HirScalarExpr| {
2443                contains = contains || expr.contains_unmaterializable()
2444            })
2445        })?;
2446        Ok(contains)
2447    }
2448
2449    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
2450    /// [`UnmaterializableFunc::MzNow`].
2451    pub fn contains_unmaterializable_except_temporal(&self) -> Result<bool, RecursionLimitError> {
2452        let mut contains = false;
2453        self.visit_post(&mut |expr| {
2454            expr.visit_children(|expr: &HirScalarExpr| {
2455                contains = contains || expr.contains_unmaterializable_except_temporal()
2456            })
2457        })?;
2458        Ok(contains)
2459    }
2460}
2461
2462impl CollectionPlan for HirRelationExpr {
2463    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2464    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2465    /// (It does explore inside subqueries.)
2466    ///
2467    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2468    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2469    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2470        if let Self::Get {
2471            id: Id::Global(id), ..
2472        } = self
2473        {
2474            out.insert(*id);
2475        }
2476        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2477    }
2478}
2479
2480impl VisitChildren<Self> for HirRelationExpr {
2481    fn visit_children<F>(&self, mut f: F)
2482    where
2483        F: FnMut(&Self),
2484    {
2485        // subqueries of type HirRelationExpr might be wrapped in
2486        // Exists or Select variants within HirScalarExpr trees
2487        // attached at the current node, and we want to visit them as well
2488        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2489            #[allow(deprecated)]
2490            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2491                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2492                    f(expr.as_ref())
2493                }
2494                _ => (),
2495            });
2496        });
2497
2498        use HirRelationExpr::*;
2499        match self {
2500            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2501            Let {
2502                name: _,
2503                id: _,
2504                value,
2505                body,
2506            } => {
2507                f(value);
2508                f(body);
2509            }
2510            LetRec {
2511                limit: _,
2512                bindings,
2513                body,
2514            } => {
2515                for (_, _, value, _) in bindings.iter() {
2516                    f(value);
2517                }
2518                f(body);
2519            }
2520            Project { input, outputs: _ } => f(input),
2521            Map { input, scalars: _ } => {
2522                f(input);
2523            }
2524            CallTable { func: _, exprs: _ } => (),
2525            Filter {
2526                input,
2527                predicates: _,
2528            } => {
2529                f(input);
2530            }
2531            Join {
2532                left,
2533                right,
2534                on: _,
2535                kind: _,
2536            } => {
2537                f(left);
2538                f(right);
2539            }
2540            Reduce {
2541                input,
2542                group_key: _,
2543                aggregates: _,
2544                expected_group_size: _,
2545            } => {
2546                f(input);
2547            }
2548            Distinct { input }
2549            | TopK {
2550                input,
2551                group_key: _,
2552                order_key: _,
2553                limit: _,
2554                offset: _,
2555                expected_group_size: _,
2556            }
2557            | Negate { input }
2558            | Threshold { input } => {
2559                f(input);
2560            }
2561            Union { base, inputs } => {
2562                f(base);
2563                for input in inputs {
2564                    f(input);
2565                }
2566            }
2567        }
2568    }
2569
2570    fn visit_mut_children<F>(&mut self, mut f: F)
2571    where
2572        F: FnMut(&mut Self),
2573    {
2574        // subqueries of type HirRelationExpr might be wrapped in
2575        // Exists or Select variants within HirScalarExpr trees
2576        // attached at the current node, and we want to visit them as well
2577        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2578            #[allow(deprecated)]
2579            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2580                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2581                    f(expr.as_mut())
2582                }
2583                _ => (),
2584            });
2585        });
2586
2587        use HirRelationExpr::*;
2588        match self {
2589            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2590            Let {
2591                name: _,
2592                id: _,
2593                value,
2594                body,
2595            } => {
2596                f(value);
2597                f(body);
2598            }
2599            LetRec {
2600                limit: _,
2601                bindings,
2602                body,
2603            } => {
2604                for (_, _, value, _) in bindings.iter_mut() {
2605                    f(value);
2606                }
2607                f(body);
2608            }
2609            Project { input, outputs: _ } => f(input),
2610            Map { input, scalars: _ } => {
2611                f(input);
2612            }
2613            CallTable { func: _, exprs: _ } => (),
2614            Filter {
2615                input,
2616                predicates: _,
2617            } => {
2618                f(input);
2619            }
2620            Join {
2621                left,
2622                right,
2623                on: _,
2624                kind: _,
2625            } => {
2626                f(left);
2627                f(right);
2628            }
2629            Reduce {
2630                input,
2631                group_key: _,
2632                aggregates: _,
2633                expected_group_size: _,
2634            } => {
2635                f(input);
2636            }
2637            Distinct { input }
2638            | TopK {
2639                input,
2640                group_key: _,
2641                order_key: _,
2642                limit: _,
2643                offset: _,
2644                expected_group_size: _,
2645            }
2646            | Negate { input }
2647            | Threshold { input } => {
2648                f(input);
2649            }
2650            Union { base, inputs } => {
2651                f(base);
2652                for input in inputs {
2653                    f(input);
2654                }
2655            }
2656        }
2657    }
2658
2659    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2660    where
2661        F: FnMut(&Self) -> Result<(), E>,
2662        E: From<RecursionLimitError>,
2663    {
2664        // subqueries of type HirRelationExpr might be wrapped in
2665        // Exists or Select variants within HirScalarExpr trees
2666        // attached at the current node, and we want to visit them as well
2667        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2668            Visit::try_visit_post(expr, &mut |expr| match expr {
2669                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2670                    f(expr.as_ref())
2671                }
2672                _ => Ok(()),
2673            })
2674        })?;
2675
2676        use HirRelationExpr::*;
2677        match self {
2678            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2679            Let {
2680                name: _,
2681                id: _,
2682                value,
2683                body,
2684            } => {
2685                f(value)?;
2686                f(body)?;
2687            }
2688            LetRec {
2689                limit: _,
2690                bindings,
2691                body,
2692            } => {
2693                for (_, _, value, _) in bindings.iter() {
2694                    f(value)?;
2695                }
2696                f(body)?;
2697            }
2698            Project { input, outputs: _ } => f(input)?,
2699            Map { input, scalars: _ } => {
2700                f(input)?;
2701            }
2702            CallTable { func: _, exprs: _ } => (),
2703            Filter {
2704                input,
2705                predicates: _,
2706            } => {
2707                f(input)?;
2708            }
2709            Join {
2710                left,
2711                right,
2712                on: _,
2713                kind: _,
2714            } => {
2715                f(left)?;
2716                f(right)?;
2717            }
2718            Reduce {
2719                input,
2720                group_key: _,
2721                aggregates: _,
2722                expected_group_size: _,
2723            } => {
2724                f(input)?;
2725            }
2726            Distinct { input }
2727            | TopK {
2728                input,
2729                group_key: _,
2730                order_key: _,
2731                limit: _,
2732                offset: _,
2733                expected_group_size: _,
2734            }
2735            | Negate { input }
2736            | Threshold { input } => {
2737                f(input)?;
2738            }
2739            Union { base, inputs } => {
2740                f(base)?;
2741                for input in inputs {
2742                    f(input)?;
2743                }
2744            }
2745        }
2746        Ok(())
2747    }
2748
2749    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2750    where
2751        F: FnMut(&mut Self) -> Result<(), E>,
2752        E: From<RecursionLimitError>,
2753    {
2754        // subqueries of type HirRelationExpr might be wrapped in
2755        // Exists or Select variants within HirScalarExpr trees
2756        // attached at the current node, and we want to visit them as well
2757        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2758            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2759                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2760                    f(expr.as_mut())
2761                }
2762                _ => Ok(()),
2763            })
2764        })?;
2765
2766        use HirRelationExpr::*;
2767        match self {
2768            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2769            Let {
2770                name: _,
2771                id: _,
2772                value,
2773                body,
2774            } => {
2775                f(value)?;
2776                f(body)?;
2777            }
2778            LetRec {
2779                limit: _,
2780                bindings,
2781                body,
2782            } => {
2783                for (_, _, value, _) in bindings.iter_mut() {
2784                    f(value)?;
2785                }
2786                f(body)?;
2787            }
2788            Project { input, outputs: _ } => f(input)?,
2789            Map { input, scalars: _ } => {
2790                f(input)?;
2791            }
2792            CallTable { func: _, exprs: _ } => (),
2793            Filter {
2794                input,
2795                predicates: _,
2796            } => {
2797                f(input)?;
2798            }
2799            Join {
2800                left,
2801                right,
2802                on: _,
2803                kind: _,
2804            } => {
2805                f(left)?;
2806                f(right)?;
2807            }
2808            Reduce {
2809                input,
2810                group_key: _,
2811                aggregates: _,
2812                expected_group_size: _,
2813            } => {
2814                f(input)?;
2815            }
2816            Distinct { input }
2817            | TopK {
2818                input,
2819                group_key: _,
2820                order_key: _,
2821                limit: _,
2822                offset: _,
2823                expected_group_size: _,
2824            }
2825            | Negate { input }
2826            | Threshold { input } => {
2827                f(input)?;
2828            }
2829            Union { base, inputs } => {
2830                f(base)?;
2831                for input in inputs {
2832                    f(input)?;
2833                }
2834            }
2835        }
2836        Ok(())
2837    }
2838}
2839
2840impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2841    fn visit_children<F>(&self, mut f: F)
2842    where
2843        F: FnMut(&HirScalarExpr),
2844    {
2845        use HirRelationExpr::*;
2846        match self {
2847            Constant { rows: _, typ: _ }
2848            | Get { id: _, typ: _ }
2849            | Let {
2850                name: _,
2851                id: _,
2852                value: _,
2853                body: _,
2854            }
2855            | LetRec {
2856                limit: _,
2857                bindings: _,
2858                body: _,
2859            }
2860            | Project {
2861                input: _,
2862                outputs: _,
2863            } => (),
2864            Map { input: _, scalars } => {
2865                for scalar in scalars {
2866                    f(scalar);
2867                }
2868            }
2869            CallTable { func: _, exprs } => {
2870                for expr in exprs {
2871                    f(expr);
2872                }
2873            }
2874            Filter {
2875                input: _,
2876                predicates,
2877            } => {
2878                for predicate in predicates {
2879                    f(predicate);
2880                }
2881            }
2882            Join {
2883                left: _,
2884                right: _,
2885                on,
2886                kind: _,
2887            } => f(on),
2888            Reduce {
2889                input: _,
2890                group_key: _,
2891                aggregates,
2892                expected_group_size: _,
2893            } => {
2894                for aggregate in aggregates {
2895                    f(aggregate.expr.as_ref());
2896                }
2897            }
2898            TopK {
2899                input: _,
2900                group_key: _,
2901                order_key: _,
2902                limit,
2903                offset,
2904                expected_group_size: _,
2905            } => {
2906                if let Some(limit) = limit {
2907                    f(limit)
2908                }
2909                f(offset)
2910            }
2911            Distinct { input: _ }
2912            | Negate { input: _ }
2913            | Threshold { input: _ }
2914            | Union { base: _, inputs: _ } => (),
2915        }
2916    }
2917
2918    fn visit_mut_children<F>(&mut self, mut f: F)
2919    where
2920        F: FnMut(&mut HirScalarExpr),
2921    {
2922        use HirRelationExpr::*;
2923        match self {
2924            Constant { rows: _, typ: _ }
2925            | Get { id: _, typ: _ }
2926            | Let {
2927                name: _,
2928                id: _,
2929                value: _,
2930                body: _,
2931            }
2932            | LetRec {
2933                limit: _,
2934                bindings: _,
2935                body: _,
2936            }
2937            | Project {
2938                input: _,
2939                outputs: _,
2940            } => (),
2941            Map { input: _, scalars } => {
2942                for scalar in scalars {
2943                    f(scalar);
2944                }
2945            }
2946            CallTable { func: _, exprs } => {
2947                for expr in exprs {
2948                    f(expr);
2949                }
2950            }
2951            Filter {
2952                input: _,
2953                predicates,
2954            } => {
2955                for predicate in predicates {
2956                    f(predicate);
2957                }
2958            }
2959            Join {
2960                left: _,
2961                right: _,
2962                on,
2963                kind: _,
2964            } => f(on),
2965            Reduce {
2966                input: _,
2967                group_key: _,
2968                aggregates,
2969                expected_group_size: _,
2970            } => {
2971                for aggregate in aggregates {
2972                    f(aggregate.expr.as_mut());
2973                }
2974            }
2975            TopK {
2976                input: _,
2977                group_key: _,
2978                order_key: _,
2979                limit,
2980                offset,
2981                expected_group_size: _,
2982            } => {
2983                if let Some(limit) = limit {
2984                    f(limit)
2985                }
2986                f(offset)
2987            }
2988            Distinct { input: _ }
2989            | Negate { input: _ }
2990            | Threshold { input: _ }
2991            | Union { base: _, inputs: _ } => (),
2992        }
2993    }
2994
2995    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2996    where
2997        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2998        E: From<RecursionLimitError>,
2999    {
3000        use HirRelationExpr::*;
3001        match self {
3002            Constant { rows: _, typ: _ }
3003            | Get { id: _, typ: _ }
3004            | Let {
3005                name: _,
3006                id: _,
3007                value: _,
3008                body: _,
3009            }
3010            | LetRec {
3011                limit: _,
3012                bindings: _,
3013                body: _,
3014            }
3015            | Project {
3016                input: _,
3017                outputs: _,
3018            } => (),
3019            Map { input: _, scalars } => {
3020                for scalar in scalars {
3021                    f(scalar)?;
3022                }
3023            }
3024            CallTable { func: _, exprs } => {
3025                for expr in exprs {
3026                    f(expr)?;
3027                }
3028            }
3029            Filter {
3030                input: _,
3031                predicates,
3032            } => {
3033                for predicate in predicates {
3034                    f(predicate)?;
3035                }
3036            }
3037            Join {
3038                left: _,
3039                right: _,
3040                on,
3041                kind: _,
3042            } => f(on)?,
3043            Reduce {
3044                input: _,
3045                group_key: _,
3046                aggregates,
3047                expected_group_size: _,
3048            } => {
3049                for aggregate in aggregates {
3050                    f(aggregate.expr.as_ref())?;
3051                }
3052            }
3053            TopK {
3054                input: _,
3055                group_key: _,
3056                order_key: _,
3057                limit,
3058                offset,
3059                expected_group_size: _,
3060            } => {
3061                if let Some(limit) = limit {
3062                    f(limit)?
3063                }
3064                f(offset)?
3065            }
3066            Distinct { input: _ }
3067            | Negate { input: _ }
3068            | Threshold { input: _ }
3069            | Union { base: _, inputs: _ } => (),
3070        }
3071        Ok(())
3072    }
3073
3074    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3075    where
3076        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3077        E: From<RecursionLimitError>,
3078    {
3079        use HirRelationExpr::*;
3080        match self {
3081            Constant { rows: _, typ: _ }
3082            | Get { id: _, typ: _ }
3083            | Let {
3084                name: _,
3085                id: _,
3086                value: _,
3087                body: _,
3088            }
3089            | LetRec {
3090                limit: _,
3091                bindings: _,
3092                body: _,
3093            }
3094            | Project {
3095                input: _,
3096                outputs: _,
3097            } => (),
3098            Map { input: _, scalars } => {
3099                for scalar in scalars {
3100                    f(scalar)?;
3101                }
3102            }
3103            CallTable { func: _, exprs } => {
3104                for expr in exprs {
3105                    f(expr)?;
3106                }
3107            }
3108            Filter {
3109                input: _,
3110                predicates,
3111            } => {
3112                for predicate in predicates {
3113                    f(predicate)?;
3114                }
3115            }
3116            Join {
3117                left: _,
3118                right: _,
3119                on,
3120                kind: _,
3121            } => f(on)?,
3122            Reduce {
3123                input: _,
3124                group_key: _,
3125                aggregates,
3126                expected_group_size: _,
3127            } => {
3128                for aggregate in aggregates {
3129                    f(aggregate.expr.as_mut())?;
3130                }
3131            }
3132            TopK {
3133                input: _,
3134                group_key: _,
3135                order_key: _,
3136                limit,
3137                offset,
3138                expected_group_size: _,
3139            } => {
3140                if let Some(limit) = limit {
3141                    f(limit)?
3142                }
3143                f(offset)?
3144            }
3145            Distinct { input: _ }
3146            | Negate { input: _ }
3147            | Threshold { input: _ }
3148            | Union { base: _, inputs: _ } => (),
3149        }
3150        Ok(())
3151    }
3152}
3153
3154impl HirScalarExpr {
3155    pub fn name(&self) -> Option<Arc<str>> {
3156        use HirScalarExpr::*;
3157        match self {
3158            Column(_, name)
3159            | Parameter(_, name)
3160            | Literal(_, _, name)
3161            | CallUnmaterializable(_, name)
3162            | CallUnary { name, .. }
3163            | CallBinary { name, .. }
3164            | CallVariadic { name, .. }
3165            | If { name, .. }
3166            | Exists(_, name)
3167            | Select(_, name)
3168            | Windowing(_, name) => name.0.clone(),
3169        }
3170    }
3171
3172    /// Replaces any parameter references in the expression with the
3173    /// corresponding datum in `params`.
3174    pub fn bind_parameters(
3175        &mut self,
3176        scx: &StatementContext,
3177        lifetime: QueryLifetime,
3178        params: &Params,
3179    ) -> Result<(), PlanError> {
3180        #[allow(deprecated)]
3181        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3182            if let HirScalarExpr::Parameter(n, name) = e {
3183                let datum = match params.datums.iter().nth(*n - 1) {
3184                    None => return Err(PlanError::UnknownParameter(*n)),
3185                    Some(datum) => datum,
3186                };
3187                let scalar_type = &params.execute_types[*n - 1];
3188                let row = Row::pack([datum]);
3189                let column_type = scalar_type.clone().nullable(datum.is_null());
3190
3191                let name = if let Some(name) = &name.0 {
3192                    Some(Arc::clone(name))
3193                } else {
3194                    Some(Arc::from(format!("${n}")))
3195                };
3196
3197                let qcx = QueryContext::root(scx, lifetime);
3198                let ecx = execute_expr_context(&qcx);
3199
3200                *e = plan_cast(
3201                    &ecx,
3202                    *EXECUTE_CAST_CONTEXT,
3203                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3204                    &params.expected_types[*n - 1],
3205                )
3206                .expect("checked in plan_params");
3207            }
3208            Ok(())
3209        })
3210    }
3211
3212    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3213    /// replaced with the corresponding expression fragment from `params` rather
3214    /// than a datum.
3215    ///
3216    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3217    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3218    /// in `self` that refer to invalid indices of `params` will cause a panic.
3219    ///
3220    /// Column references in parameters will be corrected to account for the
3221    /// depth at which they are spliced.
3222    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3223        #[allow(deprecated)]
3224        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3225                                                        e: &mut HirScalarExpr|
3226         -> Result<(), ()> {
3227            if let HirScalarExpr::Parameter(i, _name) = e {
3228                *e = params[*i - 1].clone();
3229                // Correct any column references in the parameter expression for
3230                // its new depth.
3231                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3232                    if col.level >= d {
3233                        col.level += depth
3234                    }
3235                });
3236            }
3237            Ok(())
3238        });
3239    }
3240
3241    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3242    pub fn contains_temporal(&self) -> bool {
3243        let mut contains = false;
3244        #[allow(deprecated)]
3245        self.visit_post_nolimit(&mut |e| {
3246            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3247                contains = true;
3248            }
3249        });
3250        contains
3251    }
3252
3253    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3254    pub fn contains_unmaterializable(&self) -> bool {
3255        let mut contains = false;
3256        #[allow(deprecated)]
3257        self.visit_post_nolimit(&mut |e| {
3258            if let Self::CallUnmaterializable(_, _) = e {
3259                contains = true;
3260            }
3261        });
3262        contains
3263    }
3264
3265    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
3266    /// [`UnmaterializableFunc::MzNow`].
3267    pub fn contains_unmaterializable_except_temporal(&self) -> bool {
3268        let mut contains = false;
3269        #[allow(deprecated)]
3270        self.visit_post_nolimit(&mut |e| {
3271            if let Self::CallUnmaterializable(f, _) = e {
3272                if *f != UnmaterializableFunc::MzNow {
3273                    contains = true;
3274                }
3275            }
3276        });
3277        contains
3278    }
3279
3280    /// Constructs an unnamed column reference in the current scope.
3281    /// Use [`HirScalarExpr::named_column`] when a name is known.
3282    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3283    pub fn column(index: usize) -> HirScalarExpr {
3284        HirScalarExpr::Column(
3285            ColumnRef {
3286                level: 0,
3287                column: index,
3288            },
3289            TreatAsEqual(None),
3290        )
3291    }
3292
3293    /// Constructs an unnamed column reference.
3294    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3295        HirScalarExpr::Column(cr, TreatAsEqual(None))
3296    }
3297
3298    /// Constructs a named column reference.
3299    /// Names are interned by a `NameManager`.
3300    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3301        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3302    }
3303
3304    pub fn parameter(n: usize) -> HirScalarExpr {
3305        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3306    }
3307
3308    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3309        let col_type = scalar_type.nullable(datum.is_null());
3310        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3311        let row = Row::pack([datum]);
3312        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3313    }
3314
3315    pub fn literal_true() -> HirScalarExpr {
3316        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3317    }
3318
3319    pub fn literal_false() -> HirScalarExpr {
3320        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3321    }
3322
3323    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3324        HirScalarExpr::literal(Datum::Null, scalar_type)
3325    }
3326
3327    pub fn literal_1d_array(
3328        datums: Vec<Datum>,
3329        element_scalar_type: SqlScalarType,
3330    ) -> Result<HirScalarExpr, PlanError> {
3331        let scalar_type = match element_scalar_type {
3332            SqlScalarType::Array(_) => {
3333                sql_bail!("cannot build array from array type");
3334            }
3335            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3336        };
3337
3338        let mut row = Row::default();
3339        row.packer()
3340            .try_push_array(
3341                &[ArrayDimension {
3342                    lower_bound: 1,
3343                    length: datums.len(),
3344                }],
3345                datums,
3346            )
3347            .expect("array constructed to be valid");
3348
3349        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3350    }
3351
3352    pub fn as_literal(&self) -> Option<Datum<'_>> {
3353        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3354            Some(row.unpack_first())
3355        } else {
3356            None
3357        }
3358    }
3359
3360    pub fn is_literal_true(&self) -> bool {
3361        Some(Datum::True) == self.as_literal()
3362    }
3363
3364    pub fn is_literal_false(&self) -> bool {
3365        Some(Datum::False) == self.as_literal()
3366    }
3367
3368    pub fn is_literal_null(&self) -> bool {
3369        Some(Datum::Null) == self.as_literal()
3370    }
3371
3372    /// Return true iff `self` consists only of literals, materializable function calls, and
3373    /// if-else statements.
3374    pub fn is_constant(&self) -> bool {
3375        let mut worklist = vec![self];
3376        while let Some(expr) = worklist.pop() {
3377            match expr {
3378                Self::Literal(..) => {
3379                    // leaf node, do nothing
3380                }
3381                Self::CallUnary { expr, .. } => {
3382                    worklist.push(expr);
3383                }
3384                Self::CallBinary {
3385                    func: _,
3386                    expr1,
3387                    expr2,
3388                    name: _,
3389                } => {
3390                    worklist.push(expr1);
3391                    worklist.push(expr2);
3392                }
3393                Self::CallVariadic {
3394                    func: _,
3395                    exprs,
3396                    name: _,
3397                } => {
3398                    worklist.extend(exprs.iter());
3399                }
3400                // (CallUnmaterializable is not allowed)
3401                Self::If {
3402                    cond,
3403                    then,
3404                    els,
3405                    name: _,
3406                } => {
3407                    worklist.push(cond);
3408                    worklist.push(then);
3409                    worklist.push(els);
3410                }
3411                _ => {
3412                    return false; // Any other node makes `self` non-constant.
3413                }
3414            }
3415        }
3416        true
3417    }
3418
3419    pub fn call_unary(self, func: UnaryFunc) -> Self {
3420        HirScalarExpr::CallUnary {
3421            func,
3422            expr: Box::new(self),
3423            name: NameMetadata::default(),
3424        }
3425    }
3426
3427    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3428        HirScalarExpr::CallBinary {
3429            func: func.into(),
3430            expr1: Box::new(self),
3431            expr2: Box::new(other),
3432            name: NameMetadata::default(),
3433        }
3434    }
3435
3436    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3437        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3438    }
3439
3440    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
3441        HirScalarExpr::CallVariadic {
3442            func: func.into(),
3443            exprs,
3444            name: NameMetadata::default(),
3445        }
3446    }
3447
3448    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3449        HirScalarExpr::If {
3450            cond: Box::new(cond),
3451            then: Box::new(then),
3452            els: Box::new(els),
3453            name: NameMetadata::default(),
3454        }
3455    }
3456
3457    pub fn windowing(expr: WindowExpr) -> Self {
3458        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3459    }
3460
3461    pub fn or(self, other: Self) -> Self {
3462        HirScalarExpr::call_variadic(Or, vec![self, other])
3463    }
3464
3465    pub fn and(self, other: Self) -> Self {
3466        HirScalarExpr::call_variadic(And, vec![self, other])
3467    }
3468
3469    pub fn not(self) -> Self {
3470        self.call_unary(UnaryFunc::Not(func::Not))
3471    }
3472
3473    pub fn call_is_null(self) -> Self {
3474        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3475    }
3476
3477    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3478    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3479        match args.len() {
3480            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3481            1 => args.swap_remove(0),
3482            _ => HirScalarExpr::call_variadic(And, args),
3483        }
3484    }
3485
3486    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3487    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3488        match args.len() {
3489            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3490            1 => args.swap_remove(0),
3491            _ => HirScalarExpr::call_variadic(Or, args),
3492        }
3493    }
3494
3495    pub fn take(&mut self) -> Self {
3496        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3497    }
3498
3499    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3500    /// Visits the column references in this scalar expression.
3501    ///
3502    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3503    /// which will be incremented with each subquery entered and presented to the supplied
3504    /// function `f`.
3505    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3506    where
3507        F: FnMut(usize, &ColumnRef),
3508    {
3509        #[allow(deprecated)]
3510        let _ = self.visit_recursively(depth, &mut |depth: usize,
3511                                                    e: &HirScalarExpr|
3512         -> Result<(), ()> {
3513            if let HirScalarExpr::Column(col, _name) = e {
3514                f(depth, col)
3515            }
3516            Ok(())
3517        });
3518    }
3519
3520    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3521    /// Like `visit_columns`, but permits mutating the column references.
3522    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3523    where
3524        F: FnMut(usize, &mut ColumnRef),
3525    {
3526        #[allow(deprecated)]
3527        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3528                                                        e: &mut HirScalarExpr|
3529         -> Result<(), ()> {
3530            if let HirScalarExpr::Column(col, _name) = e {
3531                f(depth, col)
3532            }
3533            Ok(())
3534        });
3535    }
3536
3537    /// Visits those column references in this scalar expression that refer to the root
3538    /// level. These include column references that are at the root level, as well as column
3539    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3540    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3541    /// "root level" to be `self`'s level.)
3542    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3543    where
3544        F: FnMut(usize),
3545    {
3546        #[allow(deprecated)]
3547        let _ = self.visit_recursively(0, &mut |depth: usize,
3548                                                e: &HirScalarExpr|
3549         -> Result<(), ()> {
3550            if let HirScalarExpr::Column(col, _name) = e {
3551                if col.level == depth {
3552                    f(col.column)
3553                }
3554            }
3555            Ok(())
3556        });
3557    }
3558
3559    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3560    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3561    where
3562        F: FnMut(&mut usize),
3563    {
3564        #[allow(deprecated)]
3565        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3566                                                    e: &mut HirScalarExpr|
3567         -> Result<(), ()> {
3568            if let HirScalarExpr::Column(col, _name) = e {
3569                if col.level == depth {
3570                    f(&mut col.column)
3571                }
3572            }
3573            Ok(())
3574        });
3575    }
3576
3577    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3578    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3579    /// in them. It takes the current depth of the expression and increases it when
3580    /// entering a subquery.
3581    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3582    where
3583        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3584    {
3585        match self {
3586            HirScalarExpr::Literal(..)
3587            | HirScalarExpr::Parameter(..)
3588            | HirScalarExpr::CallUnmaterializable(..)
3589            | HirScalarExpr::Column(..) => (),
3590            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3591            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3592                expr1.visit_recursively(depth, f)?;
3593                expr2.visit_recursively(depth, f)?;
3594            }
3595            HirScalarExpr::CallVariadic { exprs, .. } => {
3596                for expr in exprs {
3597                    expr.visit_recursively(depth, f)?;
3598                }
3599            }
3600            HirScalarExpr::If {
3601                cond,
3602                then,
3603                els,
3604                name: _,
3605            } => {
3606                cond.visit_recursively(depth, f)?;
3607                then.visit_recursively(depth, f)?;
3608                els.visit_recursively(depth, f)?;
3609            }
3610            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3611                #[allow(deprecated)]
3612                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3613                    e.visit_recursively(depth, f)
3614                })?;
3615            }
3616            HirScalarExpr::Windowing(expr, _name) => {
3617                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3618            }
3619        }
3620        f(depth, self)
3621    }
3622
3623    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3624    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3625    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3626    where
3627        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3628    {
3629        match self {
3630            HirScalarExpr::Literal(..)
3631            | HirScalarExpr::Parameter(..)
3632            | HirScalarExpr::CallUnmaterializable(..)
3633            | HirScalarExpr::Column(..) => (),
3634            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3635            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3636                expr1.visit_recursively_mut(depth, f)?;
3637                expr2.visit_recursively_mut(depth, f)?;
3638            }
3639            HirScalarExpr::CallVariadic { exprs, .. } => {
3640                for expr in exprs {
3641                    expr.visit_recursively_mut(depth, f)?;
3642                }
3643            }
3644            HirScalarExpr::If {
3645                cond,
3646                then,
3647                els,
3648                name: _,
3649            } => {
3650                cond.visit_recursively_mut(depth, f)?;
3651                then.visit_recursively_mut(depth, f)?;
3652                els.visit_recursively_mut(depth, f)?;
3653            }
3654            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3655                #[allow(deprecated)]
3656                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3657                    e.visit_recursively_mut(depth, f)
3658                })?;
3659            }
3660            HirScalarExpr::Windowing(expr, _name) => {
3661                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3662            }
3663        }
3664        f(depth, self)
3665    }
3666
3667    /// Attempts to simplify self into a literal.
3668    ///
3669    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3670    /// an evaluation error occurs during simplification, or if self contains
3671    /// - a subquery
3672    /// - a column reference to an outer level
3673    /// - a parameter
3674    /// - a window function call
3675    fn simplify_to_literal(self) -> Option<Row> {
3676        let mut expr = self
3677            .lower_uncorrelated(crate::plan::lowering::Config::default())
3678            .ok()?;
3679        // Using MIR evaluation with repr types is fine here: the
3680        // result is an untyped Row, so any intermediate type
3681        // canonicalization is discarded.
3682        expr.reduce(&[]);
3683        match expr {
3684            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3685            _ => None,
3686        }
3687    }
3688
3689    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3690    /// or an evaluation error occurs during simplification), it returns
3691    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3692    ///
3693    /// The returned error is an _internal_ error if the expression contains
3694    /// - a subquery
3695    /// - a column reference to an outer level
3696    /// - a parameter
3697    /// - a window function call
3698    ///
3699    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3700    /// msg.
3701    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3702        let mut expr = self
3703            .lower_uncorrelated(crate::plan::lowering::Config::default())
3704            .map_err(|err| {
3705                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3706            })?;
3707        // Using MIR evaluation with repr types is fine here: the
3708        // result is an untyped Row, so any intermediate type
3709        // canonicalization is discarded.
3710        expr.reduce(&[]);
3711        match expr {
3712            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3713            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3714                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3715            ),
3716            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3717                "Not a constant".to_string(),
3718            )),
3719        }
3720    }
3721
3722    /// Attempts to simplify this expression to a literal 64-bit integer.
3723    ///
3724    /// Returns `None` if this expression cannot be simplified, e.g. because it
3725    /// contains non-literal values.
3726    ///
3727    /// # Panics
3728    ///
3729    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3730    pub fn into_literal_int64(self) -> Option<i64> {
3731        self.simplify_to_literal().and_then(|row| {
3732            let datum = row.unpack_first();
3733            if datum.is_null() {
3734                None
3735            } else {
3736                Some(datum.unwrap_int64())
3737            }
3738        })
3739    }
3740
3741    /// Attempts to simplify this expression to a literal string.
3742    ///
3743    /// Returns `None` if this expression cannot be simplified, e.g. because it
3744    /// contains non-literal values.
3745    ///
3746    /// # Panics
3747    ///
3748    /// Panics if this expression does not have type [`SqlScalarType::String`].
3749    pub fn into_literal_string(self) -> Option<String> {
3750        self.simplify_to_literal().and_then(|row| {
3751            let datum = row.unpack_first();
3752            if datum.is_null() {
3753                None
3754            } else {
3755                Some(datum.unwrap_str().to_owned())
3756            }
3757        })
3758    }
3759
3760    /// Attempts to simplify this expression to a literal MzTimestamp.
3761    ///
3762    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3763    /// simplified, e.g. because it contains non-literal values or a cast fails.
3764    ///
3765    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3766    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3767    /// See `try_into_literal_int64` as an example.
3768    ///
3769    /// # Panics
3770    ///
3771    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3772    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3773        self.simplify_to_literal().and_then(|row| {
3774            let datum = row.unpack_first();
3775            if datum.is_null() {
3776                None
3777            } else {
3778                Some(datum.unwrap_mz_timestamp())
3779            }
3780        })
3781    }
3782
3783    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3784    /// returns it as an i64.
3785    ///
3786    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3787    /// - it's not a constant expression (as determined by `is_constant`)
3788    /// - evaluates to null
3789    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3790    ///
3791    /// # Panics
3792    ///
3793    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3794    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3795        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3796        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3797        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3798        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3799        // preconditions also of all the other into_literal_... functions.)
3800        if !self.is_constant() {
3801            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3802                "Expected a constant expression, got {}",
3803                self
3804            )));
3805        }
3806        self.clone()
3807            .simplify_to_literal_with_result()
3808            .and_then(|row| {
3809                let datum = row.unpack_first();
3810                if datum.is_null() {
3811                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3812                        "Expected an expression that evaluates to a non-null value, got {}",
3813                        self
3814                    )))
3815                } else {
3816                    Ok(datum.unwrap_int64())
3817                }
3818            })
3819    }
3820
3821    pub fn contains_parameters(&self) -> bool {
3822        let mut contains_parameters = false;
3823        #[allow(deprecated)]
3824        let _ = self.visit_recursively(0, &mut |_depth: usize,
3825                                                expr: &HirScalarExpr|
3826         -> Result<(), ()> {
3827            if let HirScalarExpr::Parameter(..) = expr {
3828                contains_parameters = true;
3829            }
3830            Ok(())
3831        });
3832        contains_parameters
3833    }
3834}
3835
3836impl VisitChildren<Self> for HirScalarExpr {
3837    fn visit_children<F>(&self, mut f: F)
3838    where
3839        F: FnMut(&Self),
3840    {
3841        use HirScalarExpr::*;
3842        match self {
3843            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3844            CallUnary { expr, .. } => f(expr),
3845            CallBinary { expr1, expr2, .. } => {
3846                f(expr1);
3847                f(expr2);
3848            }
3849            CallVariadic { exprs, .. } => {
3850                for expr in exprs {
3851                    f(expr);
3852                }
3853            }
3854            If {
3855                cond,
3856                then,
3857                els,
3858                name: _,
3859            } => {
3860                f(cond);
3861                f(then);
3862                f(els);
3863            }
3864            Exists(..) | Select(..) => (),
3865            Windowing(expr, _name) => expr.visit_children(f),
3866        }
3867    }
3868
3869    fn visit_mut_children<F>(&mut self, mut f: F)
3870    where
3871        F: FnMut(&mut Self),
3872    {
3873        use HirScalarExpr::*;
3874        match self {
3875            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3876            CallUnary { expr, .. } => f(expr),
3877            CallBinary { expr1, expr2, .. } => {
3878                f(expr1);
3879                f(expr2);
3880            }
3881            CallVariadic { exprs, .. } => {
3882                for expr in exprs {
3883                    f(expr);
3884                }
3885            }
3886            If {
3887                cond,
3888                then,
3889                els,
3890                name: _,
3891            } => {
3892                f(cond);
3893                f(then);
3894                f(els);
3895            }
3896            Exists(..) | Select(..) => (),
3897            Windowing(expr, _name) => expr.visit_mut_children(f),
3898        }
3899    }
3900
3901    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3902    where
3903        F: FnMut(&Self) -> Result<(), E>,
3904        E: From<RecursionLimitError>,
3905    {
3906        use HirScalarExpr::*;
3907        match self {
3908            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3909            CallUnary { expr, .. } => f(expr)?,
3910            CallBinary { expr1, expr2, .. } => {
3911                f(expr1)?;
3912                f(expr2)?;
3913            }
3914            CallVariadic { exprs, .. } => {
3915                for expr in exprs {
3916                    f(expr)?;
3917                }
3918            }
3919            If {
3920                cond,
3921                then,
3922                els,
3923                name: _,
3924            } => {
3925                f(cond)?;
3926                f(then)?;
3927                f(els)?;
3928            }
3929            Exists(..) | Select(..) => (),
3930            Windowing(expr, _name) => expr.try_visit_children(f)?,
3931        }
3932        Ok(())
3933    }
3934
3935    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3936    where
3937        F: FnMut(&mut Self) -> Result<(), E>,
3938        E: From<RecursionLimitError>,
3939    {
3940        use HirScalarExpr::*;
3941        match self {
3942            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3943            CallUnary { expr, .. } => f(expr)?,
3944            CallBinary { expr1, expr2, .. } => {
3945                f(expr1)?;
3946                f(expr2)?;
3947            }
3948            CallVariadic { exprs, .. } => {
3949                for expr in exprs {
3950                    f(expr)?;
3951                }
3952            }
3953            If {
3954                cond,
3955                then,
3956                els,
3957                name: _,
3958            } => {
3959                f(cond)?;
3960                f(then)?;
3961                f(els)?;
3962            }
3963            Exists(..) | Select(..) => (),
3964            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3965        }
3966        Ok(())
3967    }
3968}
3969
3970impl AbstractExpr for HirScalarExpr {
3971    type Type = SqlColumnType;
3972
3973    fn typ(
3974        &self,
3975        outers: &[SqlRelationType],
3976        inner: &SqlRelationType,
3977        params: &BTreeMap<usize, SqlScalarType>,
3978    ) -> Self::Type {
3979        stack::maybe_grow(|| match self {
3980            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3981                if *level == 0 {
3982                    inner.column_types[*column].clone()
3983                } else {
3984                    outers[*level - 1].column_types[*column].clone()
3985                }
3986            }
3987            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3988            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3989            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(),
3990            HirScalarExpr::CallUnary {
3991                expr,
3992                func,
3993                name: _,
3994            } => func.output_sql_type(expr.typ(outers, inner, params)),
3995            HirScalarExpr::CallBinary {
3996                expr1,
3997                expr2,
3998                func,
3999                name: _,
4000            } => func.output_sql_type(&[
4001                expr1.typ(outers, inner, params),
4002                expr2.typ(outers, inner, params),
4003            ]),
4004            HirScalarExpr::CallVariadic {
4005                exprs,
4006                func,
4007                name: _,
4008            } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
4009            HirScalarExpr::If {
4010                cond: _,
4011                then,
4012                els,
4013                name: _,
4014            } => {
4015                let then_type = then.typ(outers, inner, params);
4016                let else_type = els.typ(outers, inner, params);
4017                then_type.sql_union(&else_type).unwrap() // HIR deliberately not using `union`
4018            }
4019            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
4020            HirScalarExpr::Select(expr, _name) => {
4021                let mut outers = outers.to_vec();
4022                outers.insert(0, inner.clone());
4023                expr.typ(&outers, params)
4024                    .column_types
4025                    .into_element()
4026                    .nullable(true)
4027            }
4028            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
4029        })
4030    }
4031}
4032
4033impl AggregateExpr {
4034    pub fn typ(
4035        &self,
4036        outers: &[SqlRelationType],
4037        inner: &SqlRelationType,
4038        params: &BTreeMap<usize, SqlScalarType>,
4039    ) -> SqlColumnType {
4040        self.func
4041            .output_sql_type(self.expr.typ(outers, inner, params))
4042    }
4043
4044    /// Returns whether the expression is COUNT(*) or not.  Note that
4045    /// when we define the count builtin in sql::func, we convert
4046    /// COUNT(*) to COUNT(true), making it indistinguishable from
4047    /// literal COUNT(true), but we prefer to consider this as the
4048    /// former.
4049    ///
4050    /// (MIR has the same `is_count_asterisk`.)
4051    pub fn is_count_asterisk(&self) -> bool {
4052        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
4053    }
4054}