Skip to main content

mz_sql/plan/
hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! This file houses HIR, a representation of a SQL plan that is parallel to MIR, but represents
11//! an earlier phase of planning. It's structurally very similar to MIR, with some differences
12//! which are noted below. It gets turned into MIR via a call to lower().
13
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23// these happen to be unchanged at the moment, but there might be additions later
24use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25use mz_expr::func::variadic::{And, Or};
26pub use mz_expr::{
27    BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
28};
29use mz_ore::collections::CollectionExt;
30use mz_ore::error::ErrorExt;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::separated;
33use mz_ore::treat_as_equal::TreatAsEqual;
34use mz_ore::{soft_assert_or_log, stack};
35use mz_repr::adt::array::ArrayDimension;
36use mz_repr::adt::numeric::NumericMaxScale;
37use mz_repr::*;
38use serde::{Deserialize, Serialize};
39
40use crate::plan::error::PlanError;
41use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
42use crate::plan::typeconv::{self, CastContext, plan_cast};
43use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
44
45use super::plan_utils::GroupSizeHints;
46
/// Zero-sized tag type standing for the HIR representation.
///
/// It carries no data; it exists so that traits such as [`IR`] and [`AlgExcept`] can be
/// implemented once for the HIR relation/scalar expression types.
#[allow(missing_debug_implementations)]
pub struct Hir;
49
/// Binds the generic [`IR`] vocabulary to the concrete HIR expression types.
impl IR for Hir {
    /// Relational expressions of this IR.
    type Relation = HirRelationExpr;
    /// Scalar expressions of this IR.
    type Scalar = HirScalarExpr;
}
54
55impl AlgExcept for Hir {
56    fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
57        if *all {
58            let rhs = rhs.negate();
59            HirRelationExpr::union(lhs, rhs).threshold()
60        } else {
61            let lhs = lhs.distinct();
62            let rhs = rhs.distinct().negate();
63            HirRelationExpr::union(lhs, rhs).threshold()
64        }
65    }
66
67    fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
68        let mut result = None;
69
70        use HirRelationExpr::*;
71        if let Threshold { input } = expr {
72            if let Union { base: lhs, inputs } = input.as_ref() {
73                if let [rhs] = &inputs[..] {
74                    if let Negate { input: rhs } = rhs {
75                        match (lhs.as_ref(), rhs.as_ref()) {
76                            (Distinct { input: lhs }, Distinct { input: rhs }) => {
77                                let all = false;
78                                let lhs = lhs.as_ref();
79                                let rhs = rhs.as_ref();
80                                result = Some(Except { all, lhs, rhs })
81                            }
82                            (lhs, rhs) => {
83                                let all = true;
84                                result = Some(Except { all, lhs, rhs })
85                            }
86                        }
87                    }
88                }
89            }
90        }
91
92        result
93    }
94}
95
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirRelationExpr`], except where otherwise noted below.
///
/// Note for maintainers: `PartialOrd`/`Ord` are derived, so the declaration order of variants
/// and fields determines how values compare. Reordering anything here changes that ordering.
pub enum HirRelationExpr {
    /// A collection given inline as `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, with relation type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// Mutually recursive CTE
    LetRec {
        /// Maximum number of iterations to evaluate. If None, then there is no limit.
        limit: Option<LetRecLimit>,
        /// List of bindings all of which are in scope of each other.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// Result of the AST node.
        body: Box<HirRelationExpr>,
    },
    /// CTE
    Let {
        /// The name the bound collection goes by.
        name: String,
        /// The identifier to be used in `Get` variants to retrieve `value`.
        id: mz_expr::LocalId,
        /// The collection to be bound to `name`.
        value: Box<HirRelationExpr>,
        /// The result of the `Let`, evaluated with `name` bound to `value`.
        body: Box<HirRelationExpr>,
    },
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call to the table function `func` with the argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Unlike MirRelationExpr, we haven't yet compiled LeftOuter/RightOuter/FullOuter
    /// joins away into more primitive exprs
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        /// The join condition.
        on: HirScalarExpr,
        /// The flavour of join; outer joins are still represented as such at this level.
        kind: JoinKind,
    },
    /// Unlike MirRelationExpr, when `group_key` is empty AND `input` is empty this returns
    /// a single row with the aggregates evaluated over empty groups, rather than returning zero
    /// rows
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// Group-size hint. NOTE(review): presumably the same kind of user-supplied hint as
        /// `TopK::expected_group_size` — confirm.
        expected_group_size: Option<u64>,
    },
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Groups and orders within each group, limiting output.
    TopK {
        /// The source collection.
        input: Box<HirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain.
        /// It is of SqlScalarType::Int64.
        /// (UInt64 would make sense in theory: Then we wouldn't need to manually check
        /// non-negativity, but would just get this for free when casting to UInt64. However, Int64
        /// is better for Postgres compat. This is because if there is a $1 here, then when external
        /// tools `describe` the prepared statement, they discover this type. If what they find
        /// were UInt64, then they might have trouble calling the prepared statement, because the
        /// unsigned types are non-standard, and also don't exist even in Postgres.)
        limit: Option<HirScalarExpr>,
        /// Number of records to skip.
        /// It is of SqlScalarType::Int64.
        /// This can contain parameters at first, but by the time we reach lowering, this should
        /// already be simply a Literal.
        offset: HirScalarExpr,
        /// User-supplied hint: how many rows will have the same group key.
        expected_group_size: Option<u64>,
    },
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every collection in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
209
/// Stored column metadata: an optional, cheaply clonable name attached to a scalar expression.
///
/// NOTE(review): the `TreatAsEqual` wrapper presumably makes this metadata invisible to the
/// derived comparison/hash implementations of the types that embed it — confirm against
/// `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
212
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Just like [`mz_expr::MirScalarExpr`], except where otherwise noted below.
///
/// Every variant carries a [`NameMetadata`], either as its last tuple element or as a `name`
/// field.
pub enum HirScalarExpr {
    /// Unlike mz_expr::MirScalarExpr, we can nest HirRelationExprs via eg Exists. This means that a
    /// variable could refer to a column of the current input, or to a column of an outer relation.
    /// We use ColumnRef to denote the difference.
    Column(ColumnRef, NameMetadata),
    /// A statement parameter (e.g. `$1`), identified by its number.
    Parameter(usize, NameMetadata),
    /// A literal value, stored as a `Row` together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to a function taking no expression arguments.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a one-argument function.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a two-argument function.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function taking a list of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional: `cond` selects between `then` and `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Returns true if `expr` returns any rows
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Given `expr` with arity 1. If expr returns:
    /// * 0 rows, return NULL
    /// * 1 row, return the value of that row
    /// * >1 rows, we return an error
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function invocation; see [`WindowExpr`].
    Windowing(WindowExpr, NameMetadata),
}
264
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Represents the invocation of a window function over an optional partitioning with an optional
/// order.
pub struct WindowExpr {
    /// The window function being invoked, together with its function-specific parameters.
    pub func: WindowExprType,
    /// The expressions that define the partitioning; empty when no partitioning was given.
    pub partition_by: Vec<HirScalarExpr>,
    /// ORDER BY is represented in a complicated way: `plan_function_order_by` gave us two things:
    ///  - the `ColumnOrder`s we have put in the `order_by` fields in the `WindowExprType` in `func`
    ///    above,
    ///  - the `HirScalarExpr`s we have put in the following `order_by` field.
    /// These are separated because they are used in different places: the outer `order_by` is used
    /// in the lowering: based on it, we create a Row constructor that collects the scalar exprs;
    /// the inner `order_by` is used in the rendering to actually execute the ordering on these Rows.
    /// (`WindowExpr` exists only in HIR, but not in MIR.)
    /// Note that the `column` field in the `ColumnOrder`s point into the Row constructed in the
    /// lowering, and not to original input columns.
    pub order_by: Vec<HirScalarExpr>,
}
293
294impl WindowExpr {
295    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
296    where
297        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
298    {
299        #[allow(deprecated)]
300        self.func.visit_expressions(f)?;
301        for expr in self.partition_by.iter() {
302            f(expr)?;
303        }
304        for expr in self.order_by.iter() {
305            f(expr)?;
306        }
307        Ok(())
308    }
309
310    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
311    where
312        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
313    {
314        #[allow(deprecated)]
315        self.func.visit_expressions_mut(f)?;
316        for expr in self.partition_by.iter_mut() {
317            f(expr)?;
318        }
319        for expr in self.order_by.iter_mut() {
320            f(expr)?;
321        }
322        Ok(())
323    }
324}
325
326impl VisitChildren<HirScalarExpr> for WindowExpr {
327    fn visit_children<F>(&self, mut f: F)
328    where
329        F: FnMut(&HirScalarExpr),
330    {
331        self.func.visit_children(&mut f);
332        for expr in self.partition_by.iter() {
333            f(expr);
334        }
335        for expr in self.order_by.iter() {
336            f(expr);
337        }
338    }
339
340    fn visit_mut_children<F>(&mut self, mut f: F)
341    where
342        F: FnMut(&mut HirScalarExpr),
343    {
344        self.func.visit_mut_children(&mut f);
345        for expr in self.partition_by.iter_mut() {
346            f(expr);
347        }
348        for expr in self.order_by.iter_mut() {
349            f(expr);
350        }
351    }
352
353    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
354    where
355        F: FnMut(&HirScalarExpr) -> Result<(), E>,
356        E: From<RecursionLimitError>,
357    {
358        self.func.try_visit_children(&mut f)?;
359        for expr in self.partition_by.iter() {
360            f(expr)?;
361        }
362        for expr in self.order_by.iter() {
363            f(expr)?;
364        }
365        Ok(())
366    }
367
368    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
369    where
370        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
371        E: From<RecursionLimitError>,
372    {
373        self.func.try_visit_mut_children(&mut f)?;
374        for expr in self.partition_by.iter_mut() {
375            f(expr)?;
376        }
377        for expr in self.order_by.iter_mut() {
378            f(expr)?;
379        }
380        Ok(())
381    }
382}
383
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A window function with its parameters.
///
/// There are three types of window functions:
/// - scalar window functions, which return a different scalar value for each
///   row within a partition that depends exclusively on the position of the row
///   within the partition;
/// - value window functions, which return a scalar value for each row within a
///   partition that might be computed based on a single row, which is usually not
///   the current row (e.g., previous or following row; first or last row of the
///   partition);
/// - aggregate window functions, which compute a traditional aggregation as a
///   window function (e.g. `sum(x) OVER (...)`).
///   (Aggregate window functions can in some cases be computed by joining the
///   input relation with a reduction over the same relation that computes the
///   aggregation using the partition key as its grouping key, but we don't
///   automatically do this currently.)
pub enum WindowExprType {
    /// A scalar window function (first kind above).
    Scalar(ScalarWindowExpr),
    /// A value window function (second kind above).
    Value(ValueWindowExpr),
    /// An aggregate window function (third kind above).
    Aggregate(AggregateWindowExpr),
}
416
417impl WindowExprType {
418    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
419    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
420    where
421        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
422    {
423        #[allow(deprecated)]
424        match self {
425            Self::Scalar(expr) => expr.visit_expressions(f),
426            Self::Value(expr) => expr.visit_expressions(f),
427            Self::Aggregate(expr) => expr.visit_expressions(f),
428        }
429    }
430
431    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
432    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
433    where
434        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
435    {
436        #[allow(deprecated)]
437        match self {
438            Self::Scalar(expr) => expr.visit_expressions_mut(f),
439            Self::Value(expr) => expr.visit_expressions_mut(f),
440            Self::Aggregate(expr) => expr.visit_expressions_mut(f),
441        }
442    }
443
444    fn typ(
445        &self,
446        outers: &[SqlRelationType],
447        inner: &SqlRelationType,
448        params: &BTreeMap<usize, SqlScalarType>,
449    ) -> SqlColumnType {
450        match self {
451            Self::Scalar(expr) => expr.typ(outers, inner, params),
452            Self::Value(expr) => expr.typ(outers, inner, params),
453            Self::Aggregate(expr) => expr.typ(outers, inner, params),
454        }
455    }
456}
457
458impl VisitChildren<HirScalarExpr> for WindowExprType {
459    fn visit_children<F>(&self, f: F)
460    where
461        F: FnMut(&HirScalarExpr),
462    {
463        match self {
464            Self::Scalar(_) => (),
465            Self::Value(expr) => expr.visit_children(f),
466            Self::Aggregate(expr) => expr.visit_children(f),
467        }
468    }
469
470    fn visit_mut_children<F>(&mut self, f: F)
471    where
472        F: FnMut(&mut HirScalarExpr),
473    {
474        match self {
475            Self::Scalar(_) => (),
476            Self::Value(expr) => expr.visit_mut_children(f),
477            Self::Aggregate(expr) => expr.visit_mut_children(f),
478        }
479    }
480
481    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
482    where
483        F: FnMut(&HirScalarExpr) -> Result<(), E>,
484        E: From<RecursionLimitError>,
485    {
486        match self {
487            Self::Scalar(_) => Ok(()),
488            Self::Value(expr) => expr.try_visit_children(f),
489            Self::Aggregate(expr) => expr.try_visit_children(f),
490        }
491    }
492
493    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
494    where
495        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
496        E: From<RecursionLimitError>,
497    {
498        match self {
499            Self::Scalar(_) => Ok(()),
500            Self::Value(expr) => expr.try_visit_mut_children(f),
501            Self::Aggregate(expr) => expr.try_visit_mut_children(f),
502        }
503    }
504}
505
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A scalar window function together with the ordering it is evaluated under.
pub struct ScalarWindowExpr {
    /// Which scalar window function this is.
    pub func: ScalarWindowFunc,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
}
521
522impl ScalarWindowExpr {
523    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
524    pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
525    where
526        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
527    {
528        match self.func {
529            ScalarWindowFunc::RowNumber => {}
530            ScalarWindowFunc::Rank => {}
531            ScalarWindowFunc::DenseRank => {}
532        }
533        Ok(())
534    }
535
536    #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
537    pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
538    where
539        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
540    {
541        match self.func {
542            ScalarWindowFunc::RowNumber => {}
543            ScalarWindowFunc::Rank => {}
544            ScalarWindowFunc::DenseRank => {}
545        }
546        Ok(())
547    }
548
549    fn typ(
550        &self,
551        _outers: &[SqlRelationType],
552        _inner: &SqlRelationType,
553        _params: &BTreeMap<usize, SqlScalarType>,
554    ) -> SqlColumnType {
555        self.func.output_sql_type()
556    }
557
558    pub fn into_expr(self) -> mz_expr::AggregateFunc {
559        match self.func {
560            ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
561                order_by: self.order_by,
562            },
563            ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
564                order_by: self.order_by,
565            },
566            ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
567                order_by: self.order_by,
568            },
569        }
570    }
571}
572
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Scalar Window functions
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
590
591impl Display for ScalarWindowFunc {
592    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
593        match self {
594            ScalarWindowFunc::RowNumber => write!(f, "row_number"),
595            ScalarWindowFunc::Rank => write!(f, "rank"),
596            ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
597        }
598    }
599}
600
601impl ScalarWindowFunc {
602    pub fn output_sql_type(&self) -> SqlColumnType {
603        match self {
604            ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
605            ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
606            ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
607        }
608    }
609}
610
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// A value window function together with its arguments and window parameters.
pub struct ValueWindowExpr {
    /// Which value window function this is.
    pub func: ValueWindowFunc,
    /// If the argument list has a single element (e.g., for `first_value`), then it's that element.
    /// If the argument list has multiple elements (e.g., for `lag`), then it's encoded in a record,
    /// e.g., `row(#1, 3, null)`.
    /// If it's a fused window function, then the arguments of each of the constituent function
    /// calls are wrapped in an outer record.
    pub args: Box<HirScalarExpr>,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame; handed to `ValueWindowFunc::into_expr`, which forwards it to
    /// `first_value`/`last_value` only.
    pub window_frame: WindowFrame,
    /// The `ignore_nulls` flag; handed to `ValueWindowFunc::into_expr`, which forwards it to
    /// `lag`/`lead` only.
    pub ignore_nulls: bool,
}
635
636impl Display for ValueWindowFunc {
637    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
638        match self {
639            ValueWindowFunc::Lag => write!(f, "lag"),
640            ValueWindowFunc::Lead => write!(f, "lead"),
641            ValueWindowFunc::FirstValue => write!(f, "first_value"),
642            ValueWindowFunc::LastValue => write!(f, "last_value"),
643            ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
644        }
645    }
646}
647
648impl ValueWindowExpr {
649    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
650    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
651    where
652        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
653    {
654        f(&self.args)
655    }
656
657    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
658    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
659    where
660        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
661    {
662        f(&mut self.args)
663    }
664
665    fn typ(
666        &self,
667        outers: &[SqlRelationType],
668        inner: &SqlRelationType,
669        params: &BTreeMap<usize, SqlScalarType>,
670    ) -> SqlColumnType {
671        self.func
672            .output_sql_type(self.args.typ(outers, inner, params))
673    }
674
675    /// Converts into `mz_expr::AggregateFunc`.
676    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
677        (
678            self.args,
679            self.func
680                .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
681        )
682    }
683}
684
685impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
686    fn visit_children<F>(&self, mut f: F)
687    where
688        F: FnMut(&HirScalarExpr),
689    {
690        f(&self.args)
691    }
692
693    fn visit_mut_children<F>(&mut self, mut f: F)
694    where
695        F: FnMut(&mut HirScalarExpr),
696    {
697        f(&mut self.args)
698    }
699
700    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
701    where
702        F: FnMut(&HirScalarExpr) -> Result<(), E>,
703        E: From<RecursionLimitError>,
704    {
705        f(&self.args)
706    }
707
708    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
709    where
710        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
711        E: From<RecursionLimitError>,
712    {
713        f(&mut self.args)
714    }
715}
716
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// Value Window functions
pub enum ValueWindowFunc {
    /// Displayed as `lag`.
    Lag,
    /// Displayed as `lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions fused into one; its output is a record with one field per
    /// constituent function (see `output_sql_type`).
    Fused(Vec<ValueWindowFunc>),
}
736
737impl ValueWindowFunc {
738    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
739        match self {
740            ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
741                // The input is a (value, offset, default) record, so extract the type of the first arg
742                input_type.scalar_type.unwrap_record_element_type()[0]
743                    .clone()
744                    .nullable(true)
745            }
746            ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
747                input_type.scalar_type.nullable(true)
748            }
749            ValueWindowFunc::Fused(funcs) => {
750                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
751                SqlScalarType::Record {
752                    fields: funcs
753                        .iter()
754                        .zip_eq(input_types)
755                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
756                        .collect(),
757                    custom_id: None,
758                }
759                .nullable(false)
760            }
761        }
762    }
763
764    pub fn into_expr(
765        self,
766        order_by: Vec<ColumnOrder>,
767        window_frame: WindowFrame,
768        ignore_nulls: bool,
769    ) -> mz_expr::AggregateFunc {
770        match self {
771            // Lag and Lead are fundamentally the same function, just with opposite directions
772            ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
773                order_by,
774                lag_lead: mz_expr::LagLeadType::Lag,
775                ignore_nulls,
776            },
777            ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
778                order_by,
779                lag_lead: mz_expr::LagLeadType::Lead,
780                ignore_nulls,
781            },
782            ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
783                order_by,
784                window_frame,
785            },
786            ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
787                order_by,
788                window_frame,
789            },
790            ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
791                funcs: funcs
792                    .into_iter()
793                    .map(|func| {
794                        func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
795                    })
796                    .collect(),
797                order_by,
798            },
799        }
800    }
801}
802
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
/// An aggregation evaluated as a window function (e.g. `sum(x) OVER (...)`).
pub struct AggregateWindowExpr {
    /// The aggregate function and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// See comment on `WindowExpr::order_by`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame; handed on unchanged by `into_expr`.
    pub window_frame: WindowFrame,
}
819
820impl AggregateWindowExpr {
821    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
822    pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
823    where
824        F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
825    {
826        f(&self.aggregate_expr.expr)
827    }
828
829    #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
830    pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
831    where
832        F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
833    {
834        f(&mut self.aggregate_expr.expr)
835    }
836
837    fn typ(
838        &self,
839        outers: &[SqlRelationType],
840        inner: &SqlRelationType,
841        params: &BTreeMap<usize, SqlScalarType>,
842    ) -> SqlColumnType {
843        self.aggregate_expr
844            .func
845            .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params))
846    }
847
848    pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
849        if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
850            (
851                self.aggregate_expr.expr,
852                FusedWindowAggregate {
853                    wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
854                    order_by: self.order_by,
855                    window_frame: self.window_frame,
856                },
857            )
858        } else {
859            (
860                self.aggregate_expr.expr,
861                WindowAggregate {
862                    wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
863                    order_by: self.order_by,
864                    window_frame: self.window_frame,
865                },
866            )
867        }
868    }
869}
870
871impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
872    fn visit_children<F>(&self, mut f: F)
873    where
874        F: FnMut(&HirScalarExpr),
875    {
876        f(&self.aggregate_expr.expr)
877    }
878
879    fn visit_mut_children<F>(&mut self, mut f: F)
880    where
881        F: FnMut(&mut HirScalarExpr),
882    {
883        f(&mut self.aggregate_expr.expr)
884    }
885
886    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
887    where
888        F: FnMut(&HirScalarExpr) -> Result<(), E>,
889        E: From<RecursionLimitError>,
890    {
891        f(&self.aggregate_expr.expr)
892    }
893
894    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
895    where
896        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
897        E: From<RecursionLimitError>,
898    {
899        f(&mut self.aggregate_expr.expr)
900    }
901}
902
/// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully
/// determined. Several SQL expressions can be freely coerced based upon where
/// in the expression tree they appear. For example, the string literal '42'
/// will be automatically coerced to the integer 42 if used in a numeric
/// context:
///
/// ```sql
/// SELECT '42' + 42
/// ```
///
/// This separate type gives the code that needs to interact with coercions very
/// fine-grained control over what coercions happen and when.
///
/// The primary driver of coercion is function and operator selection, as
/// choosing the correct function or operator implementation depends on the type
/// of the provided arguments. Coercion also occurs at the very root of the
/// scalar expression tree. For example in
///
/// ```sql
/// SELECT ... WHERE $1
/// ```
///
/// the `WHERE` clause will coerce the contained unconstrained type parameter
/// `$1` to have type bool.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that is already an ordinary [`HirScalarExpr`].
    Coerced(HirScalarExpr),
    /// A parameter such as `$1` whose type is still unconstrained.
    /// NOTE(review): the payload is presumably the parameter number — confirm.
    Parameter(usize),
    /// A `NULL` literal that has not been given a type yet.
    LiteralNull,
    /// A string literal such as `'42'` that may still be coerced to another type.
    LiteralString(String),
    /// A record literal whose fields may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
935
936impl CoercibleScalarExpr {
937    pub fn type_as(
938        self,
939        ecx: &ExprContext,
940        ty: &SqlScalarType,
941    ) -> Result<HirScalarExpr, PlanError> {
942        let expr = typeconv::plan_coerce(ecx, self, ty)?;
943        let expr_ty = ecx.scalar_type(&expr);
944        if ty != &expr_ty {
945            sql_bail!(
946                "{} must have type {}, not type {}",
947                ecx.name,
948                ecx.humanize_sql_scalar_type(ty, false),
949                ecx.humanize_sql_scalar_type(&expr_ty, false),
950            );
951        }
952        Ok(expr)
953    }
954
955    pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
956        typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
957    }
958
959    pub fn cast_to(
960        self,
961        ecx: &ExprContext,
962        ccx: CastContext,
963        ty: &SqlScalarType,
964    ) -> Result<HirScalarExpr, PlanError> {
965        let expr = typeconv::plan_coerce(ecx, self, ty)?;
966        typeconv::plan_cast(ecx, ccx, expr, ty)
967    }
968}
969
970/// The column type for a [`CoercibleScalarExpr`].
971#[derive(Clone, Debug)]
972pub enum CoercibleColumnType {
973    Coerced(SqlColumnType),
974    Record(Vec<CoercibleColumnType>),
975    Uncoerced,
976}
977
978impl CoercibleColumnType {
979    /// Reports the nullability of the type.
980    pub fn nullable(&self) -> bool {
981        match self {
982            // A coerced value's nullability is known.
983            CoercibleColumnType::Coerced(ct) => ct.nullable,
984
985            // A literal record can never be null.
986            CoercibleColumnType::Record(_) => false,
987
988            // An uncoerced literal may be the literal `NULL`, so we have
989            // to conservatively assume it is nullable.
990            CoercibleColumnType::Uncoerced => true,
991        }
992    }
993}
994
995/// The scalar type for a [`CoercibleScalarExpr`].
996#[derive(Clone, Debug)]
997pub enum CoercibleScalarType {
998    Coerced(SqlScalarType),
999    Record(Vec<CoercibleColumnType>),
1000    Uncoerced,
1001}
1002
1003impl CoercibleScalarType {
1004    /// Reports whether the scalar type has been coerced.
1005    pub fn is_coerced(&self) -> bool {
1006        matches!(self, CoercibleScalarType::Coerced(_))
1007    }
1008
1009    /// Returns the coerced scalar type, if the type is coerced.
1010    pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1011        match self {
1012            CoercibleScalarType::Coerced(t) => Some(t),
1013            _ => None,
1014        }
1015    }
1016
1017    /// If the type is coerced, apply the mapping function to the contained
1018    /// scalar type.
1019    pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1020    where
1021        F: FnOnce(SqlScalarType) -> SqlScalarType,
1022    {
1023        match self {
1024            CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1025            _ => self,
1026        }
1027    }
1028
1029    /// If the type is an coercible record, forcibly converts to a coerced
1030    /// record type. Any uncoerced field types are assumed to be of type text.
1031    ///
1032    /// Generally you should prefer to use [`typeconv::plan_coerce`], which
1033    /// accepts a type hint that can indicate the types of uncoerced field
1034    /// types.
1035    pub fn force_coerced_if_record(&mut self) {
1036        fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1037            let mut fields = vec![];
1038            for (i, uf) in uncoerced_fields.enumerate() {
1039                let name = ColumnName::from(format!("f{}", i + 1));
1040                let ty = match uf {
1041                    CoercibleColumnType::Coerced(ty) => ty,
1042                    CoercibleColumnType::Record(mut fields) => {
1043                        convert(fields.drain(..)).nullable(false)
1044                    }
1045                    CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1046                };
1047                fields.push((name, ty))
1048            }
1049            SqlScalarType::Record {
1050                fields: fields.into(),
1051                custom_id: None,
1052            }
1053        }
1054
1055        if let CoercibleScalarType::Record(fields) = self {
1056            *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1057        }
1058    }
1059}
1060
1061/// An expression whose type can be ascertained.
1062///
1063/// Abstracts over `ScalarExpr` and `CoercibleScalarExpr`.
1064pub trait AbstractExpr {
1065    type Type: AbstractColumnType;
1066
1067    /// Computes the type of the expression.
1068    fn typ(
1069        &self,
1070        outers: &[SqlRelationType],
1071        inner: &SqlRelationType,
1072        params: &BTreeMap<usize, SqlScalarType>,
1073    ) -> Self::Type;
1074}
1075
1076impl AbstractExpr for CoercibleScalarExpr {
1077    type Type = CoercibleColumnType;
1078
1079    fn typ(
1080        &self,
1081        outers: &[SqlRelationType],
1082        inner: &SqlRelationType,
1083        params: &BTreeMap<usize, SqlScalarType>,
1084    ) -> Self::Type {
1085        match self {
1086            CoercibleScalarExpr::Coerced(expr) => {
1087                CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1088            }
1089            CoercibleScalarExpr::LiteralRecord(scalars) => {
1090                let fields = scalars
1091                    .iter()
1092                    .map(|s| s.typ(outers, inner, params))
1093                    .collect();
1094                CoercibleColumnType::Record(fields)
1095            }
1096            _ => CoercibleColumnType::Uncoerced,
1097        }
1098    }
1099}
1100
/// A column type-like object from which a scalar type-like object can be
/// extracted.
///
/// Implemented for both `SqlColumnType` and `CoercibleColumnType`.
pub trait AbstractColumnType {
    /// The scalar type-like object produced by
    /// [`AbstractColumnType::scalar_type`].
    type AbstractScalarType;

    /// Consumes the column type-like object and returns its inner scalar
    /// type-like object.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1112
1113impl AbstractColumnType for SqlColumnType {
1114    type AbstractScalarType = SqlScalarType;
1115
1116    fn scalar_type(self) -> Self::AbstractScalarType {
1117        self.scalar_type
1118    }
1119}
1120
1121impl AbstractColumnType for CoercibleColumnType {
1122    type AbstractScalarType = CoercibleScalarType;
1123
1124    fn scalar_type(self) -> Self::AbstractScalarType {
1125        match self {
1126            CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1127            CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1128            CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1129        }
1130    }
1131}
1132
1133impl From<HirScalarExpr> for CoercibleScalarExpr {
1134    fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1135        CoercibleScalarExpr::Coerced(expr)
1136    }
1137}
1138
1139/// A leveled column reference.
1140///
1141/// In the course of decorrelation, multiple levels of nested subqueries are
1142/// traversed, and references to columns may correspond to different levels
1143/// of containing outer subqueries.
1144///
1145/// A `ColumnRef` allows expressions to refer to columns while being clear
1146/// about which level the column references without manually performing the
1147/// bookkeeping tracking their actual column locations.
1148///
1149/// Specifically, a `ColumnRef` refers to a column `level` subquery level *out*
1150/// from the reference, using `column` as a unique identifier in that subquery level.
1151/// A `level` of zero corresponds to the current scope, and levels increase to
1152/// indicate subqueries further "outwards".
1153#[derive(
1154    Debug,
1155    Clone,
1156    Copy,
1157    PartialEq,
1158    Eq,
1159    Hash,
1160    Ord,
1161    PartialOrd,
1162    Serialize,
1163    Deserialize
1164)]
1165pub struct ColumnRef {
1166    // scope level, where 0 is the current scope and 1+ are outer scopes.
1167    pub level: usize,
1168    // level-local column identifier used.
1169    pub column: usize,
1170}
1171
1172#[derive(
1173    Debug,
1174    Clone,
1175    PartialEq,
1176    Eq,
1177    PartialOrd,
1178    Ord,
1179    Hash,
1180    Serialize,
1181    Deserialize
1182)]
1183pub enum JoinKind {
1184    Inner,
1185    LeftOuter,
1186    RightOuter,
1187    FullOuter,
1188}
1189
1190impl fmt::Display for JoinKind {
1191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1192        write!(
1193            f,
1194            "{}",
1195            match self {
1196                JoinKind::Inner => "Inner",
1197                JoinKind::LeftOuter => "LeftOuter",
1198                JoinKind::RightOuter => "RightOuter",
1199                JoinKind::FullOuter => "FullOuter",
1200            }
1201        )
1202    }
1203}
1204
1205impl JoinKind {
1206    pub fn can_be_correlated(&self) -> bool {
1207        match self {
1208            JoinKind::Inner | JoinKind::LeftOuter => true,
1209            JoinKind::RightOuter | JoinKind::FullOuter => false,
1210        }
1211    }
1212
1213    pub fn can_elide_identity_left_join(&self) -> bool {
1214        match self {
1215            JoinKind::Inner | JoinKind::RightOuter => true,
1216            JoinKind::LeftOuter | JoinKind::FullOuter => false,
1217        }
1218    }
1219
1220    pub fn can_elide_identity_right_join(&self) -> bool {
1221        match self {
1222            JoinKind::Inner | JoinKind::LeftOuter => true,
1223            JoinKind::RightOuter | JoinKind::FullOuter => false,
1224        }
1225    }
1226}
1227
1228#[derive(
1229    Debug,
1230    Clone,
1231    PartialEq,
1232    Eq,
1233    PartialOrd,
1234    Ord,
1235    Hash,
1236    Serialize,
1237    Deserialize
1238)]
1239pub struct AggregateExpr {
1240    pub func: AggregateFunc,
1241    pub expr: Box<HirScalarExpr>,
1242    pub distinct: bool,
1243}
1244
1245/// Aggregate functions analogous to `mz_expr::AggregateFunc`, but whose
1246/// types may be different.
1247///
1248/// Specifically, the nullability of the aggregate columns is more common
1249/// here than in `expr`, as these aggregates may be applied over empty
1250/// result sets and should be null in those cases, whereas `expr` variants
1251/// only return null values when supplied nulls as input.
1252#[derive(
1253    Clone,
1254    Debug,
1255    Eq,
1256    PartialEq,
1257    PartialOrd,
1258    Ord,
1259    Hash,
1260    Serialize,
1261    Deserialize
1262)]
1263pub enum AggregateFunc {
1264    MaxNumeric,
1265    MaxInt16,
1266    MaxInt32,
1267    MaxInt64,
1268    MaxUInt16,
1269    MaxUInt32,
1270    MaxUInt64,
1271    MaxMzTimestamp,
1272    MaxFloat32,
1273    MaxFloat64,
1274    MaxBool,
1275    MaxString,
1276    MaxDate,
1277    MaxTimestamp,
1278    MaxTimestampTz,
1279    MaxInterval,
1280    MaxTime,
1281    MinNumeric,
1282    MinInt16,
1283    MinInt32,
1284    MinInt64,
1285    MinUInt16,
1286    MinUInt32,
1287    MinUInt64,
1288    MinMzTimestamp,
1289    MinFloat32,
1290    MinFloat64,
1291    MinBool,
1292    MinString,
1293    MinDate,
1294    MinTimestamp,
1295    MinTimestampTz,
1296    MinInterval,
1297    MinTime,
1298    SumInt16,
1299    SumInt32,
1300    SumInt64,
1301    SumUInt16,
1302    SumUInt32,
1303    SumUInt64,
1304    SumFloat32,
1305    SumFloat64,
1306    SumNumeric,
1307    Count,
1308    Any,
1309    All,
1310    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`s
1311    /// into a JSON list. The other elements are columns used by `order_by`.
1312    ///
1313    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
1314    /// layer, this function filters out `Datum::Null`, for consistency with
1315    /// the other aggregate functions.
1316    JsonbAgg {
1317        order_by: Vec<ColumnOrder>,
1318    },
1319    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum`s into a
1320    /// JSON map. The other elements are columns used by `order_by`.
1321    JsonbObjectAgg {
1322        order_by: Vec<ColumnOrder>,
1323    },
1324    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
1325    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
1326    /// elements are columns used by `order_by`.
1327    MapAgg {
1328        order_by: Vec<ColumnOrder>,
1329        value_type: SqlScalarType,
1330    },
1331    /// Accumulates `Datum::List`s whose first element is a `Datum::Array` into a
1332    /// single `Datum::Array`. The other elements are columns used by `order_by`.
1333    ArrayConcat {
1334        order_by: Vec<ColumnOrder>,
1335    },
1336    /// Accumulates `Datum::List`s whose first element is a `Datum::List` into a
1337    /// single `Datum::List`. The other elements are columns used by `order_by`.
1338    ListConcat {
1339        order_by: Vec<ColumnOrder>,
1340    },
1341    StringAgg {
1342        order_by: Vec<ColumnOrder>,
1343    },
1344    /// A bundle of fused window aggregations: its input is a record, whose each
1345    /// component will be the input to one of the `AggregateFunc`s.
1346    ///
1347    /// Importantly, this aggregation can only be present inside a `WindowExpr`,
1348    /// more specifically an `AggregateWindowExpr`.
1349    FusedWindowAgg {
1350        funcs: Vec<AggregateFunc>,
1351    },
1352    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
1353    ///
1354    /// Useful for removing an expensive aggregation while maintaining the shape
1355    /// of a reduce operator.
1356    Dummy,
1357}
1358
1359impl AggregateFunc {
1360    /// Converts the `sql::AggregateFunc` to a corresponding `mz_expr::AggregateFunc`.
1361    pub fn into_expr(self) -> mz_expr::AggregateFunc {
1362        match self {
1363            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1364            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1365            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1366            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1367            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1368            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1369            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1370            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1371            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1372            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1373            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1374            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1375            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1376            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1377            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1378            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1379            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1380            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1381            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1382            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1383            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1384            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1385            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1386            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1387            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1388            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1389            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1390            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1391            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1392            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1393            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1394            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1395            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1396            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1397            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1398            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1399            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1400            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1401            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1402            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1403            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1404            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1405            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1406            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1407            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1408            AggregateFunc::All => mz_expr::AggregateFunc::All,
1409            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1410            AggregateFunc::JsonbObjectAgg { order_by } => {
1411                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1412            }
1413            AggregateFunc::MapAgg {
1414                order_by,
1415                value_type,
1416            } => mz_expr::AggregateFunc::MapAgg {
1417                order_by,
1418                value_type,
1419            },
1420            AggregateFunc::ArrayConcat { order_by } => {
1421                mz_expr::AggregateFunc::ArrayConcat { order_by }
1422            }
1423            AggregateFunc::ListConcat { order_by } => {
1424                mz_expr::AggregateFunc::ListConcat { order_by }
1425            }
1426            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1427            // `AggregateFunc::FusedWindowAgg` should be specially handled in
1428            // `AggregateWindowExpr::into_expr`.
1429            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1430                panic!("into_expr called on FusedWindowAgg")
1431            }
1432            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1433        }
1434    }
1435
1436    /// Returns a datum whose inclusion in the aggregation will not change its
1437    /// result.
1438    ///
1439    /// # Panics
1440    ///
1441    /// Panics if called on a `FusedWindowAgg`.
1442    pub fn identity_datum(&self) -> Datum<'static> {
1443        match self {
1444            AggregateFunc::Any => Datum::False,
1445            AggregateFunc::All => Datum::True,
1446            AggregateFunc::Dummy => Datum::Dummy,
1447            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1448            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1449            AggregateFunc::MaxNumeric
1450            | AggregateFunc::MaxInt16
1451            | AggregateFunc::MaxInt32
1452            | AggregateFunc::MaxInt64
1453            | AggregateFunc::MaxUInt16
1454            | AggregateFunc::MaxUInt32
1455            | AggregateFunc::MaxUInt64
1456            | AggregateFunc::MaxMzTimestamp
1457            | AggregateFunc::MaxFloat32
1458            | AggregateFunc::MaxFloat64
1459            | AggregateFunc::MaxBool
1460            | AggregateFunc::MaxString
1461            | AggregateFunc::MaxDate
1462            | AggregateFunc::MaxTimestamp
1463            | AggregateFunc::MaxTimestampTz
1464            | AggregateFunc::MaxInterval
1465            | AggregateFunc::MaxTime
1466            | AggregateFunc::MinNumeric
1467            | AggregateFunc::MinInt16
1468            | AggregateFunc::MinInt32
1469            | AggregateFunc::MinInt64
1470            | AggregateFunc::MinUInt16
1471            | AggregateFunc::MinUInt32
1472            | AggregateFunc::MinUInt64
1473            | AggregateFunc::MinMzTimestamp
1474            | AggregateFunc::MinFloat32
1475            | AggregateFunc::MinFloat64
1476            | AggregateFunc::MinBool
1477            | AggregateFunc::MinString
1478            | AggregateFunc::MinDate
1479            | AggregateFunc::MinTimestamp
1480            | AggregateFunc::MinTimestampTz
1481            | AggregateFunc::MinInterval
1482            | AggregateFunc::MinTime
1483            | AggregateFunc::SumInt16
1484            | AggregateFunc::SumInt32
1485            | AggregateFunc::SumInt64
1486            | AggregateFunc::SumUInt16
1487            | AggregateFunc::SumUInt32
1488            | AggregateFunc::SumUInt64
1489            | AggregateFunc::SumFloat32
1490            | AggregateFunc::SumFloat64
1491            | AggregateFunc::SumNumeric
1492            | AggregateFunc::Count
1493            | AggregateFunc::JsonbAgg { .. }
1494            | AggregateFunc::JsonbObjectAgg { .. }
1495            | AggregateFunc::MapAgg { .. }
1496            | AggregateFunc::StringAgg { .. } => Datum::Null,
1497            AggregateFunc::FusedWindowAgg { funcs: _ } => {
1498                // `identity_datum` is used only in HIR planning, and `FusedWindowAgg` can't occur
1499                // in HIR planning, because it is introduced only during HIR transformation.
1500                //
1501                // The implementation could be something like the following, except that we need to
1502                // return a `Datum<'static>`, so we can't actually dynamically compute this.
1503                // ```
1504                // let temp_storage = RowArena::new();
1505                // temp_storage.make_datum(|packer| packer.push_list(funcs.iter().map(|f| f.identity_datum())))
1506                // ```
1507                panic!("FusedWindowAgg doesn't have an identity_datum")
1508            }
1509        }
1510    }
1511
1512    /// The output column type for the result of an aggregation.
1513    ///
1514    /// The output column type also contains nullability information, which
1515    /// is (without further information) true for aggregations that are not
1516    /// counts.
1517    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1518        let scalar_type = match self {
1519            AggregateFunc::Count => SqlScalarType::Int64,
1520            AggregateFunc::Any => SqlScalarType::Bool,
1521            AggregateFunc::All => SqlScalarType::Bool,
1522            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1523            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1524            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1525            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1526            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1527                max_scale: Some(NumericMaxScale::ZERO),
1528            },
1529            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1530            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1531                max_scale: Some(NumericMaxScale::ZERO),
1532            },
1533            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1534                value_type: Box::new(value_type.clone()),
1535                custom_id: None,
1536            },
1537            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1538                match input_type.scalar_type {
1539                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
1540                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1541                    _ => unreachable!(),
1542                }
1543            }
1544            AggregateFunc::MaxNumeric
1545            | AggregateFunc::MaxInt16
1546            | AggregateFunc::MaxInt32
1547            | AggregateFunc::MaxInt64
1548            | AggregateFunc::MaxUInt16
1549            | AggregateFunc::MaxUInt32
1550            | AggregateFunc::MaxUInt64
1551            | AggregateFunc::MaxMzTimestamp
1552            | AggregateFunc::MaxFloat32
1553            | AggregateFunc::MaxFloat64
1554            | AggregateFunc::MaxBool
1555            | AggregateFunc::MaxString
1556            | AggregateFunc::MaxDate
1557            | AggregateFunc::MaxTimestamp
1558            | AggregateFunc::MaxTimestampTz
1559            | AggregateFunc::MaxInterval
1560            | AggregateFunc::MaxTime
1561            | AggregateFunc::MinNumeric
1562            | AggregateFunc::MinInt16
1563            | AggregateFunc::MinInt32
1564            | AggregateFunc::MinInt64
1565            | AggregateFunc::MinUInt16
1566            | AggregateFunc::MinUInt32
1567            | AggregateFunc::MinUInt64
1568            | AggregateFunc::MinMzTimestamp
1569            | AggregateFunc::MinFloat32
1570            | AggregateFunc::MinFloat64
1571            | AggregateFunc::MinBool
1572            | AggregateFunc::MinString
1573            | AggregateFunc::MinDate
1574            | AggregateFunc::MinTimestamp
1575            | AggregateFunc::MinTimestampTz
1576            | AggregateFunc::MinInterval
1577            | AggregateFunc::MinTime
1578            | AggregateFunc::SumFloat32
1579            | AggregateFunc::SumFloat64
1580            | AggregateFunc::SumNumeric
1581            | AggregateFunc::Dummy => input_type.scalar_type,
1582            AggregateFunc::FusedWindowAgg { funcs } => {
1583                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1584                SqlScalarType::Record {
1585                    fields: funcs
1586                        .iter()
1587                        .zip_eq(input_types)
1588                        .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
1589                        .collect(),
1590                    custom_id: None,
1591                }
1592            }
1593        };
1594        // max/min/sum return null on empty sets
1595        let nullable = !matches!(self, AggregateFunc::Count);
1596        scalar_type.nullable(nullable)
1597    }
1598
1599    pub fn is_order_sensitive(&self) -> bool {
1600        use AggregateFunc::*;
1601        matches!(
1602            self,
1603            JsonbAgg { .. }
1604                | JsonbObjectAgg { .. }
1605                | MapAgg { .. }
1606                | ArrayConcat { .. }
1607                | ListConcat { .. }
1608                | StringAgg { .. }
1609        )
1610    }
1611}
1612
1613impl HirRelationExpr {
1614    /// Gets the SQL type of a self-contained, top-level expression.
1615    pub fn top_level_typ(&self) -> SqlRelationType {
1616        self.typ(&[], &BTreeMap::new())
1617    }
1618
1619    /// Gets the SQL type of the expression.
1620    ///
1621    /// `outers` gives types for outer relations.
1622    /// `params` gives types for parameters.
1623    pub fn typ(
1624        &self,
1625        outers: &[SqlRelationType],
1626        params: &BTreeMap<usize, SqlScalarType>,
1627    ) -> SqlRelationType {
1628        stack::maybe_grow(|| match self {
1629            HirRelationExpr::Constant { typ, .. } => typ.clone(),
1630            HirRelationExpr::Get { typ, .. } => typ.clone(),
1631            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1632            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1633            HirRelationExpr::Project { input, outputs } => {
1634                let input_typ = input.typ(outers, params);
1635                SqlRelationType::new(
1636                    outputs
1637                        .iter()
1638                        .map(|&i| input_typ.column_types[i].clone())
1639                        .collect(),
1640                )
1641            }
1642            HirRelationExpr::Map { input, scalars } => {
1643                let mut typ = input.typ(outers, params);
1644                for scalar in scalars {
1645                    typ.column_types.push(scalar.typ(outers, &typ, params));
1646                }
1647                typ
1648            }
1649            HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(),
1650            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1651                input.typ(outers, params)
1652            }
1653            HirRelationExpr::Join {
1654                left, right, kind, ..
1655            } => {
1656                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1657                let right_nullable =
1658                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1659                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1660                    let nullable = t.nullable || left_nullable;
1661                    t.nullable(nullable)
1662                });
1663                let mut outers = outers.to_vec();
1664                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1665                let rt = right
1666                    .typ(&outers, params)
1667                    .column_types
1668                    .into_iter()
1669                    .map(|t| {
1670                        let nullable = t.nullable || right_nullable;
1671                        t.nullable(nullable)
1672                    });
1673                SqlRelationType::new(lt.chain(rt).collect())
1674            }
1675            HirRelationExpr::Reduce {
1676                input,
1677                group_key,
1678                aggregates,
1679                expected_group_size: _,
1680            } => {
1681                let input_typ = input.typ(outers, params);
1682                let mut column_types = group_key
1683                    .iter()
1684                    .map(|&i| input_typ.column_types[i].clone())
1685                    .collect::<Vec<_>>();
1686                for agg in aggregates {
1687                    column_types.push(agg.typ(outers, &input_typ, params));
1688                }
1689                // TODO(frank): add primary key information.
1690                SqlRelationType::new(column_types)
1691            }
1692            // TODO(frank): check for removal; add primary key information.
1693            HirRelationExpr::Distinct { input }
1694            | HirRelationExpr::Negate { input }
1695            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1696            HirRelationExpr::Union { base, inputs } => {
1697                let mut base_cols = base.typ(outers, params).column_types;
1698                for input in inputs {
1699                    for (base_col, col) in base_cols
1700                        .iter_mut()
1701                        .zip_eq(input.typ(outers, params).column_types)
1702                    {
1703                        *base_col = base_col.sql_union(&col).unwrap(); // HIR deliberately not using `union`
1704                    }
1705                }
1706                SqlRelationType::new(base_cols)
1707            }
1708        })
1709    }
1710
1711    pub fn arity(&self) -> usize {
1712        match self {
1713            HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1714            HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1715            HirRelationExpr::Let { body, .. } => body.arity(),
1716            HirRelationExpr::LetRec { body, .. } => body.arity(),
1717            HirRelationExpr::Project { outputs, .. } => outputs.len(),
1718            HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1719            HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1720            HirRelationExpr::Filter { input, .. }
1721            | HirRelationExpr::TopK { input, .. }
1722            | HirRelationExpr::Distinct { input }
1723            | HirRelationExpr::Negate { input }
1724            | HirRelationExpr::Threshold { input } => input.arity(),
1725            HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1726            HirRelationExpr::Union { base, .. } => base.arity(),
1727            HirRelationExpr::Reduce {
1728                group_key,
1729                aggregates,
1730                ..
1731            } => group_key.len() + aggregates.len(),
1732        }
1733    }
1734
1735    /// If self is a constant, return the value and the type, otherwise `None`.
1736    pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1737        match self {
1738            Self::Constant { rows, typ } => Some((rows, typ)),
1739            _ => None,
1740        }
1741    }
1742
1743    /// Reports whether this expression contains a column reference to its
1744    /// direct parent scope.
1745    pub fn is_correlated(&self) -> bool {
1746        let mut correlated = false;
1747        #[allow(deprecated)]
1748        self.visit_columns(0, &mut |depth, col| {
1749            if col.level > depth && col.level - depth == 1 {
1750                correlated = true;
1751            }
1752        });
1753        correlated
1754    }
1755
1756    pub fn is_join_identity(&self) -> bool {
1757        match self {
1758            HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1759            _ => false,
1760        }
1761    }
1762
1763    pub fn project(self, outputs: Vec<usize>) -> Self {
1764        if outputs.iter().copied().eq(0..self.arity()) {
1765            // The projection is trivial. Suppress it.
1766            self
1767        } else {
1768            HirRelationExpr::Project {
1769                input: Box::new(self),
1770                outputs,
1771            }
1772        }
1773    }
1774
1775    pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1776        if scalars.is_empty() {
1777            // The map is trivial. Suppress it.
1778            self
1779        } else if let HirRelationExpr::Map {
1780            scalars: old_scalars,
1781            input: _,
1782        } = &mut self
1783        {
1784            // Map applied to a map. Fuse the maps.
1785            old_scalars.extend(scalars);
1786            self
1787        } else {
1788            HirRelationExpr::Map {
1789                input: Box::new(self),
1790                scalars,
1791            }
1792        }
1793    }
1794
1795    pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1796        if let HirRelationExpr::Filter {
1797            input: _,
1798            predicates,
1799        } = &mut self
1800        {
1801            predicates.extend(preds);
1802            predicates.sort();
1803            predicates.dedup();
1804            self
1805        } else {
1806            preds.sort();
1807            preds.dedup();
1808            HirRelationExpr::Filter {
1809                input: Box::new(self),
1810                predicates: preds,
1811            }
1812        }
1813    }
1814
1815    pub fn reduce(
1816        self,
1817        group_key: Vec<usize>,
1818        aggregates: Vec<AggregateExpr>,
1819        expected_group_size: Option<u64>,
1820    ) -> Self {
1821        HirRelationExpr::Reduce {
1822            input: Box::new(self),
1823            group_key,
1824            aggregates,
1825            expected_group_size,
1826        }
1827    }
1828
1829    pub fn top_k(
1830        self,
1831        group_key: Vec<usize>,
1832        order_key: Vec<ColumnOrder>,
1833        limit: Option<HirScalarExpr>,
1834        offset: HirScalarExpr,
1835        expected_group_size: Option<u64>,
1836    ) -> Self {
1837        HirRelationExpr::TopK {
1838            input: Box::new(self),
1839            group_key,
1840            order_key,
1841            limit,
1842            offset,
1843            expected_group_size,
1844        }
1845    }
1846
1847    pub fn negate(self) -> Self {
1848        if let HirRelationExpr::Negate { input } = self {
1849            *input
1850        } else {
1851            HirRelationExpr::Negate {
1852                input: Box::new(self),
1853            }
1854        }
1855    }
1856
1857    pub fn distinct(self) -> Self {
1858        if let HirRelationExpr::Distinct { .. } = self {
1859            self
1860        } else {
1861            HirRelationExpr::Distinct {
1862                input: Box::new(self),
1863            }
1864        }
1865    }
1866
1867    pub fn threshold(self) -> Self {
1868        if let HirRelationExpr::Threshold { .. } = self {
1869            self
1870        } else {
1871            HirRelationExpr::Threshold {
1872                input: Box::new(self),
1873            }
1874        }
1875    }
1876
1877    pub fn union(self, other: Self) -> Self {
1878        let mut terms = Vec::new();
1879        if let HirRelationExpr::Union { base, inputs } = self {
1880            terms.push(*base);
1881            terms.extend(inputs);
1882        } else {
1883            terms.push(self);
1884        }
1885        if let HirRelationExpr::Union { base, inputs } = other {
1886            terms.push(*base);
1887            terms.extend(inputs);
1888        } else {
1889            terms.push(other);
1890        }
1891        HirRelationExpr::Union {
1892            base: Box::new(terms.remove(0)),
1893            inputs: terms,
1894        }
1895    }
1896
1897    pub fn exists(self) -> HirScalarExpr {
1898        HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1899    }
1900
1901    pub fn select(self) -> HirScalarExpr {
1902        HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1903    }
1904
1905    pub fn join(
1906        self,
1907        mut right: HirRelationExpr,
1908        on: HirScalarExpr,
1909        kind: JoinKind,
1910    ) -> HirRelationExpr {
1911        if self.is_join_identity()
1912            && !right.is_correlated()
1913            && on == HirScalarExpr::literal_true()
1914            && kind.can_elide_identity_left_join()
1915        {
1916            // The join can be elided, but we need to adjust column references
1917            // on the right-hand side to account for the removal of the scope
1918            // introduced by the join.
1919            #[allow(deprecated)]
1920            right.visit_columns_mut(0, &mut |depth, col| {
1921                if col.level > depth {
1922                    col.level -= 1;
1923                }
1924            });
1925            right
1926        } else if right.is_join_identity()
1927            && on == HirScalarExpr::literal_true()
1928            && kind.can_elide_identity_right_join()
1929        {
1930            self
1931        } else {
1932            HirRelationExpr::Join {
1933                left: Box::new(self),
1934                right: Box::new(right),
1935                on,
1936                kind,
1937            }
1938        }
1939    }
1940
1941    pub fn take(&mut self) -> HirRelationExpr {
1942        mem::replace(
1943            self,
1944            HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1945        )
1946    }
1947
1948    #[deprecated = "Use `Visit::visit_post`."]
1949    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1950    where
1951        F: FnMut(&'a Self, usize),
1952    {
1953        #[allow(deprecated)]
1954        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1955                                                 depth: usize|
1956         -> Result<(), ()> {
1957            f(e, depth);
1958            Ok(())
1959        });
1960    }
1961
1962    #[deprecated = "Use `Visit::try_visit_post`."]
1963    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1964    where
1965        F: FnMut(&'a Self, usize) -> Result<(), E>,
1966    {
1967        #[allow(deprecated)]
1968        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1969            e.visit_fallible(depth, f)
1970        })?;
1971        f(self, depth)
1972    }
1973
1974    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1975    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1976    where
1977        F: FnMut(&'a Self, usize) -> Result<(), E>,
1978    {
1979        match self {
1980            HirRelationExpr::Constant { .. }
1981            | HirRelationExpr::Get { .. }
1982            | HirRelationExpr::CallTable { .. } => (),
1983            HirRelationExpr::Let { body, value, .. } => {
1984                f(value, depth)?;
1985                f(body, depth)?;
1986            }
1987            HirRelationExpr::LetRec {
1988                limit: _,
1989                bindings,
1990                body,
1991            } => {
1992                for (_, _, value, _) in bindings.iter() {
1993                    f(value, depth)?;
1994                }
1995                f(body, depth)?;
1996            }
1997            HirRelationExpr::Project { input, .. } => {
1998                f(input, depth)?;
1999            }
2000            HirRelationExpr::Map { input, .. } => {
2001                f(input, depth)?;
2002            }
2003            HirRelationExpr::Filter { input, .. } => {
2004                f(input, depth)?;
2005            }
2006            HirRelationExpr::Join { left, right, .. } => {
2007                f(left, depth)?;
2008                f(right, depth + 1)?;
2009            }
2010            HirRelationExpr::Reduce { input, .. } => {
2011                f(input, depth)?;
2012            }
2013            HirRelationExpr::Distinct { input } => {
2014                f(input, depth)?;
2015            }
2016            HirRelationExpr::TopK { input, .. } => {
2017                f(input, depth)?;
2018            }
2019            HirRelationExpr::Negate { input } => {
2020                f(input, depth)?;
2021            }
2022            HirRelationExpr::Threshold { input } => {
2023                f(input, depth)?;
2024            }
2025            HirRelationExpr::Union { base, inputs } => {
2026                f(base, depth)?;
2027                for input in inputs {
2028                    f(input, depth)?;
2029                }
2030            }
2031        }
2032        Ok(())
2033    }
2034
2035    #[deprecated = "Use `Visit::visit_mut_post` instead."]
2036    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2037    where
2038        F: FnMut(&mut Self, usize),
2039    {
2040        #[allow(deprecated)]
2041        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2042                                                     depth: usize|
2043         -> Result<(), ()> {
2044            f(e, depth);
2045            Ok(())
2046        });
2047    }
2048
2049    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2050    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2051    where
2052        F: FnMut(&mut Self, usize) -> Result<(), E>,
2053    {
2054        #[allow(deprecated)]
2055        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2056            e.visit_mut_fallible(depth, f)
2057        })?;
2058        f(self, depth)
2059    }
2060
2061    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2062    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2063    where
2064        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2065    {
2066        match self {
2067            HirRelationExpr::Constant { .. }
2068            | HirRelationExpr::Get { .. }
2069            | HirRelationExpr::CallTable { .. } => (),
2070            HirRelationExpr::Let { body, value, .. } => {
2071                f(value, depth)?;
2072                f(body, depth)?;
2073            }
2074            HirRelationExpr::LetRec {
2075                limit: _,
2076                bindings,
2077                body,
2078            } => {
2079                for (_, _, value, _) in bindings.iter_mut() {
2080                    f(value, depth)?;
2081                }
2082                f(body, depth)?;
2083            }
2084            HirRelationExpr::Project { input, .. } => {
2085                f(input, depth)?;
2086            }
2087            HirRelationExpr::Map { input, .. } => {
2088                f(input, depth)?;
2089            }
2090            HirRelationExpr::Filter { input, .. } => {
2091                f(input, depth)?;
2092            }
2093            HirRelationExpr::Join { left, right, .. } => {
2094                f(left, depth)?;
2095                f(right, depth + 1)?;
2096            }
2097            HirRelationExpr::Reduce { input, .. } => {
2098                f(input, depth)?;
2099            }
2100            HirRelationExpr::Distinct { input } => {
2101                f(input, depth)?;
2102            }
2103            HirRelationExpr::TopK { input, .. } => {
2104                f(input, depth)?;
2105            }
2106            HirRelationExpr::Negate { input } => {
2107                f(input, depth)?;
2108            }
2109            HirRelationExpr::Threshold { input } => {
2110                f(input, depth)?;
2111            }
2112            HirRelationExpr::Union { base, inputs } => {
2113                f(base, depth)?;
2114                for input in inputs {
2115                    f(input, depth)?;
2116                }
2117            }
2118        }
2119        Ok(())
2120    }
2121
2122    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2123    /// Visits all scalar expressions within the sub-tree of the given relation.
2124    ///
2125    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2126    /// which will be incremented when entering the RHS of a join or a subquery and
2127    /// presented to the supplied function `f`.
2128    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2129    where
2130        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2131    {
2132        #[allow(deprecated)]
2133        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2134                                         depth: usize|
2135         -> Result<(), E> {
2136            match e {
2137                HirRelationExpr::Join { on, .. } => {
2138                    f(on, depth)?;
2139                }
2140                HirRelationExpr::Map { scalars, .. } => {
2141                    for scalar in scalars {
2142                        f(scalar, depth)?;
2143                    }
2144                }
2145                HirRelationExpr::CallTable { exprs, .. } => {
2146                    for expr in exprs {
2147                        f(expr, depth)?;
2148                    }
2149                }
2150                HirRelationExpr::Filter { predicates, .. } => {
2151                    for predicate in predicates {
2152                        f(predicate, depth)?;
2153                    }
2154                }
2155                HirRelationExpr::Reduce { aggregates, .. } => {
2156                    for aggregate in aggregates {
2157                        f(&aggregate.expr, depth)?;
2158                    }
2159                }
2160                HirRelationExpr::TopK { limit, offset, .. } => {
2161                    if let Some(limit) = limit {
2162                        f(limit, depth)?;
2163                    }
2164                    f(offset, depth)?;
2165                }
2166                HirRelationExpr::Union { .. }
2167                | HirRelationExpr::Let { .. }
2168                | HirRelationExpr::LetRec { .. }
2169                | HirRelationExpr::Project { .. }
2170                | HirRelationExpr::Distinct { .. }
2171                | HirRelationExpr::Negate { .. }
2172                | HirRelationExpr::Threshold { .. }
2173                | HirRelationExpr::Constant { .. }
2174                | HirRelationExpr::Get { .. } => (),
2175            }
2176            Ok(())
2177        })
2178    }
2179
2180    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2181    /// Like `visit_scalar_expressions`, but permits mutating the expressions.
2182    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2183    where
2184        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2185    {
2186        #[allow(deprecated)]
2187        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2188                                             depth: usize|
2189         -> Result<(), E> {
2190            match e {
2191                HirRelationExpr::Join { on, .. } => {
2192                    f(on, depth)?;
2193                }
2194                HirRelationExpr::Map { scalars, .. } => {
2195                    for scalar in scalars.iter_mut() {
2196                        f(scalar, depth)?;
2197                    }
2198                }
2199                HirRelationExpr::CallTable { exprs, .. } => {
2200                    for expr in exprs.iter_mut() {
2201                        f(expr, depth)?;
2202                    }
2203                }
2204                HirRelationExpr::Filter { predicates, .. } => {
2205                    for predicate in predicates.iter_mut() {
2206                        f(predicate, depth)?;
2207                    }
2208                }
2209                HirRelationExpr::Reduce { aggregates, .. } => {
2210                    for aggregate in aggregates.iter_mut() {
2211                        f(&mut aggregate.expr, depth)?;
2212                    }
2213                }
2214                HirRelationExpr::TopK { limit, offset, .. } => {
2215                    if let Some(limit) = limit {
2216                        f(limit, depth)?;
2217                    }
2218                    f(offset, depth)?;
2219                }
2220                HirRelationExpr::Union { .. }
2221                | HirRelationExpr::Let { .. }
2222                | HirRelationExpr::LetRec { .. }
2223                | HirRelationExpr::Project { .. }
2224                | HirRelationExpr::Distinct { .. }
2225                | HirRelationExpr::Negate { .. }
2226                | HirRelationExpr::Threshold { .. }
2227                | HirRelationExpr::Constant { .. }
2228                | HirRelationExpr::Get { .. } => (),
2229            }
2230            Ok(())
2231        })
2232    }
2233
2234    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2235    /// Visits the column references in this relation expression.
2236    ///
2237    /// The `depth` argument should indicate the subquery nesting depth of the expression,
2238    /// which will be incremented when entering the RHS of a join or a subquery and
2239    /// presented to the supplied function `f`.
2240    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2241    where
2242        F: FnMut(usize, &ColumnRef),
2243    {
2244        #[allow(deprecated)]
2245        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2246                                                           depth: usize|
2247         -> Result<(), ()> {
2248            e.visit_columns(depth, f);
2249            Ok(())
2250        });
2251    }
2252
2253    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2254    /// Like `visit_columns`, but permits mutating the column references.
2255    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2256    where
2257        F: FnMut(usize, &mut ColumnRef),
2258    {
2259        #[allow(deprecated)]
2260        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2261                                                               depth: usize|
2262         -> Result<(), ()> {
2263            e.visit_columns_mut(depth, f);
2264            Ok(())
2265        });
2266    }
2267
2268    /// Replaces any parameter references in the expression with the
2269    /// corresponding datum from `params`.
2270    pub fn bind_parameters(
2271        &mut self,
2272        scx: &StatementContext,
2273        lifetime: QueryLifetime,
2274        params: &Params,
2275    ) -> Result<(), PlanError> {
2276        #[allow(deprecated)]
2277        self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2278            e.bind_parameters(scx, lifetime, params)
2279        })
2280    }
2281
2282    pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2283        let mut contains_parameters = false;
2284        #[allow(deprecated)]
2285        self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2286            if e.contains_parameters() {
2287                contains_parameters = true;
2288            }
2289            Ok::<(), PlanError>(())
2290        })?;
2291        Ok(contains_parameters)
2292    }
2293
2294    /// See the documentation for [`HirScalarExpr::splice_parameters`].
2295    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2296        #[allow(deprecated)]
2297        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2298                                                               depth: usize|
2299         -> Result<(), ()> {
2300            e.splice_parameters(params, depth);
2301            Ok(())
2302        });
2303    }
2304
2305    /// Constructs a constant collection from specific rows and schema.
2306    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2307        let rows = rows
2308            .into_iter()
2309            .map(move |datums| Row::pack_slice(&datums))
2310            .collect();
2311        HirRelationExpr::Constant { rows, typ }
2312    }
2313
2314    /// A `RowSetFinishing` can only be directly applied to the result of a one-shot select.
2315    /// This function is concerned with maintained queries, e.g., an index or materialized view.
2316    /// Instead of directly applying the given `RowSetFinishing`, it converts the `RowSetFinishing`
2317    /// to a `TopK`, which it then places at the top of `self`. Additionally, it turns the given
2318    /// finishing into a trivial finishing.
2319    pub fn finish_maintained(
2320        &mut self,
2321        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2322        group_size_hints: GroupSizeHints,
2323    ) {
2324        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2325            let old_finishing = mem::replace(
2326                finishing,
2327                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2328            );
2329            *self = HirRelationExpr::top_k(
2330                std::mem::replace(
2331                    self,
2332                    HirRelationExpr::Constant {
2333                        rows: vec![],
2334                        typ: SqlRelationType::new(Vec::new()),
2335                    },
2336                ),
2337                vec![],
2338                old_finishing.order_by,
2339                old_finishing.limit,
2340                old_finishing.offset,
2341                group_size_hints.limit_input_group_size,
2342            )
2343            .project(old_finishing.project);
2344        }
2345    }
2346
2347    /// Returns a trivial finishing, i.e., that does nothing to the result set.
2348    ///
2349    /// (There is also `RowSetFinishing::trivial`, but that is specialized for when the O generic
2350    /// parameter is not an HirScalarExpr anymore.)
2351    pub fn trivial_row_set_finishing_hir(
2352        arity: usize,
2353    ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2354        RowSetFinishing {
2355            order_by: Vec::new(),
2356            limit: None,
2357            offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2358            project: (0..arity).collect(),
2359        }
2360    }
2361
2362    /// True if the finishing does nothing to any result set.
2363    ///
2364    /// (There is also `RowSetFinishing::is_trivial`, but that is specialized for when the O generic
2365    /// parameter is not an HirScalarExpr anymore.)
2366    pub fn is_trivial_row_set_finishing_hir(
2367        rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2368        arity: usize,
2369    ) -> bool {
2370        rsf.limit.is_none()
2371            && rsf.order_by.is_empty()
2372            && rsf
2373                .offset
2374                .clone()
2375                .try_into_literal_int64()
2376                .is_ok_and(|o| o == 0)
2377            && rsf.project.iter().copied().eq(0..arity)
2378    }
2379
2380    /// The HirRelationExpr is considered potentially expensive if and only if
2381    /// at least one of the following conditions is true:
2382    ///
2383    ///  - It contains at least one HirScalarExpr with a function call.
2384    ///  - It contains at least one CallTable or a Reduce operator.
2385    ///  - We run into a RecursionLimitError while analyzing the expression.
2386    ///
2387    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2388    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2389    pub fn could_run_expensive_function(&self) -> bool {
2390        let mut result = false;
2391        if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2392            use HirRelationExpr::*;
2393            use HirScalarExpr::*;
2394
2395            e.visit_children(|scalar: &HirScalarExpr| {
2396                if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2397                    result |= match scalar {
2398                        Column(..)
2399                        | Literal(..)
2400                        | CallUnmaterializable(..)
2401                        | If { .. }
2402                        | Parameter(..)
2403                        | Select(..)
2404                        | Exists(..) => false,
2405                        // Function calls are considered expensive
2406                        CallUnary { .. }
2407                        | CallBinary { .. }
2408                        | CallVariadic { .. }
2409                        | Windowing(..) => true,
2410                    };
2411                }) {
2412                    // Conservatively set `true` on RecursionLimitError.
2413                    result = true;
2414                }
2415            });
2416
2417            // CallTable has a table function; Reduce has an aggregate function.
2418            // Other constructs use MirScalarExpr to run a function
2419            result |= matches!(e, CallTable { .. } | Reduce { .. });
2420        }) {
2421            // Conservatively set `true` on RecursionLimitError.
2422            result = true;
2423        }
2424
2425        result
2426    }
2427
2428    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
2429    pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2430        let mut contains = false;
2431        self.visit_post(&mut |expr| {
2432            expr.visit_children(|expr: &HirScalarExpr| {
2433                contains = contains || expr.contains_temporal()
2434            })
2435        })?;
2436        Ok(contains)
2437    }
2438
2439    /// Whether the expression contains any [`UnmaterializableFunc`] call.
2440    pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2441        let mut contains = false;
2442        self.visit_post(&mut |expr| {
2443            expr.visit_children(|expr: &HirScalarExpr| {
2444                contains = contains || expr.contains_unmaterializable()
2445            })
2446        })?;
2447        Ok(contains)
2448    }
2449
2450    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
2451    /// [`UnmaterializableFunc::MzNow`].
2452    pub fn contains_unmaterializable_except_temporal(&self) -> Result<bool, RecursionLimitError> {
2453        let mut contains = false;
2454        self.visit_post(&mut |expr| {
2455            expr.visit_children(|expr: &HirScalarExpr| {
2456                contains = contains || expr.contains_unmaterializable_except_temporal()
2457            })
2458        })?;
2459        Ok(contains)
2460    }
2461}
2462
2463impl CollectionPlan for HirRelationExpr {
2464    /// Collects the global collections that this HIR expression directly depends on, i.e., that it
2465    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2466    /// (It does explore inside subqueries.)
2467    ///
2468    /// !!!WARNING!!!: this method has an MirRelationExpr counterpart. The two
2469    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2470    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2471        if let Self::Get {
2472            id: Id::Global(id), ..
2473        } = self
2474        {
2475            out.insert(*id);
2476        }
2477        self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2478    }
2479}
2480
2481impl VisitChildren<Self> for HirRelationExpr {
2482    fn visit_children<F>(&self, mut f: F)
2483    where
2484        F: FnMut(&Self),
2485    {
2486        // subqueries of type HirRelationExpr might be wrapped in
2487        // Exists or Select variants within HirScalarExpr trees
2488        // attached at the current node, and we want to visit them as well
2489        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2490            #[allow(deprecated)]
2491            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2492                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2493                    f(expr.as_ref())
2494                }
2495                _ => (),
2496            });
2497        });
2498
2499        use HirRelationExpr::*;
2500        match self {
2501            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2502            Let {
2503                name: _,
2504                id: _,
2505                value,
2506                body,
2507            } => {
2508                f(value);
2509                f(body);
2510            }
2511            LetRec {
2512                limit: _,
2513                bindings,
2514                body,
2515            } => {
2516                for (_, _, value, _) in bindings.iter() {
2517                    f(value);
2518                }
2519                f(body);
2520            }
2521            Project { input, outputs: _ } => f(input),
2522            Map { input, scalars: _ } => {
2523                f(input);
2524            }
2525            CallTable { func: _, exprs: _ } => (),
2526            Filter {
2527                input,
2528                predicates: _,
2529            } => {
2530                f(input);
2531            }
2532            Join {
2533                left,
2534                right,
2535                on: _,
2536                kind: _,
2537            } => {
2538                f(left);
2539                f(right);
2540            }
2541            Reduce {
2542                input,
2543                group_key: _,
2544                aggregates: _,
2545                expected_group_size: _,
2546            } => {
2547                f(input);
2548            }
2549            Distinct { input }
2550            | TopK {
2551                input,
2552                group_key: _,
2553                order_key: _,
2554                limit: _,
2555                offset: _,
2556                expected_group_size: _,
2557            }
2558            | Negate { input }
2559            | Threshold { input } => {
2560                f(input);
2561            }
2562            Union { base, inputs } => {
2563                f(base);
2564                for input in inputs {
2565                    f(input);
2566                }
2567            }
2568        }
2569    }
2570
2571    fn visit_mut_children<F>(&mut self, mut f: F)
2572    where
2573        F: FnMut(&mut Self),
2574    {
2575        // subqueries of type HirRelationExpr might be wrapped in
2576        // Exists or Select variants within HirScalarExpr trees
2577        // attached at the current node, and we want to visit them as well
2578        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2579            #[allow(deprecated)]
2580            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2581                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2582                    f(expr.as_mut())
2583                }
2584                _ => (),
2585            });
2586        });
2587
2588        use HirRelationExpr::*;
2589        match self {
2590            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2591            Let {
2592                name: _,
2593                id: _,
2594                value,
2595                body,
2596            } => {
2597                f(value);
2598                f(body);
2599            }
2600            LetRec {
2601                limit: _,
2602                bindings,
2603                body,
2604            } => {
2605                for (_, _, value, _) in bindings.iter_mut() {
2606                    f(value);
2607                }
2608                f(body);
2609            }
2610            Project { input, outputs: _ } => f(input),
2611            Map { input, scalars: _ } => {
2612                f(input);
2613            }
2614            CallTable { func: _, exprs: _ } => (),
2615            Filter {
2616                input,
2617                predicates: _,
2618            } => {
2619                f(input);
2620            }
2621            Join {
2622                left,
2623                right,
2624                on: _,
2625                kind: _,
2626            } => {
2627                f(left);
2628                f(right);
2629            }
2630            Reduce {
2631                input,
2632                group_key: _,
2633                aggregates: _,
2634                expected_group_size: _,
2635            } => {
2636                f(input);
2637            }
2638            Distinct { input }
2639            | TopK {
2640                input,
2641                group_key: _,
2642                order_key: _,
2643                limit: _,
2644                offset: _,
2645                expected_group_size: _,
2646            }
2647            | Negate { input }
2648            | Threshold { input } => {
2649                f(input);
2650            }
2651            Union { base, inputs } => {
2652                f(base);
2653                for input in inputs {
2654                    f(input);
2655                }
2656            }
2657        }
2658    }
2659
2660    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2661    where
2662        F: FnMut(&Self) -> Result<(), E>,
2663        E: From<RecursionLimitError>,
2664    {
2665        // subqueries of type HirRelationExpr might be wrapped in
2666        // Exists or Select variants within HirScalarExpr trees
2667        // attached at the current node, and we want to visit them as well
2668        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2669            Visit::try_visit_post(expr, &mut |expr| match expr {
2670                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2671                    f(expr.as_ref())
2672                }
2673                _ => Ok(()),
2674            })
2675        })?;
2676
2677        use HirRelationExpr::*;
2678        match self {
2679            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2680            Let {
2681                name: _,
2682                id: _,
2683                value,
2684                body,
2685            } => {
2686                f(value)?;
2687                f(body)?;
2688            }
2689            LetRec {
2690                limit: _,
2691                bindings,
2692                body,
2693            } => {
2694                for (_, _, value, _) in bindings.iter() {
2695                    f(value)?;
2696                }
2697                f(body)?;
2698            }
2699            Project { input, outputs: _ } => f(input)?,
2700            Map { input, scalars: _ } => {
2701                f(input)?;
2702            }
2703            CallTable { func: _, exprs: _ } => (),
2704            Filter {
2705                input,
2706                predicates: _,
2707            } => {
2708                f(input)?;
2709            }
2710            Join {
2711                left,
2712                right,
2713                on: _,
2714                kind: _,
2715            } => {
2716                f(left)?;
2717                f(right)?;
2718            }
2719            Reduce {
2720                input,
2721                group_key: _,
2722                aggregates: _,
2723                expected_group_size: _,
2724            } => {
2725                f(input)?;
2726            }
2727            Distinct { input }
2728            | TopK {
2729                input,
2730                group_key: _,
2731                order_key: _,
2732                limit: _,
2733                offset: _,
2734                expected_group_size: _,
2735            }
2736            | Negate { input }
2737            | Threshold { input } => {
2738                f(input)?;
2739            }
2740            Union { base, inputs } => {
2741                f(base)?;
2742                for input in inputs {
2743                    f(input)?;
2744                }
2745            }
2746        }
2747        Ok(())
2748    }
2749
2750    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2751    where
2752        F: FnMut(&mut Self) -> Result<(), E>,
2753        E: From<RecursionLimitError>,
2754    {
2755        // subqueries of type HirRelationExpr might be wrapped in
2756        // Exists or Select variants within HirScalarExpr trees
2757        // attached at the current node, and we want to visit them as well
2758        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2759            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2760                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2761                    f(expr.as_mut())
2762                }
2763                _ => Ok(()),
2764            })
2765        })?;
2766
2767        use HirRelationExpr::*;
2768        match self {
2769            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2770            Let {
2771                name: _,
2772                id: _,
2773                value,
2774                body,
2775            } => {
2776                f(value)?;
2777                f(body)?;
2778            }
2779            LetRec {
2780                limit: _,
2781                bindings,
2782                body,
2783            } => {
2784                for (_, _, value, _) in bindings.iter_mut() {
2785                    f(value)?;
2786                }
2787                f(body)?;
2788            }
2789            Project { input, outputs: _ } => f(input)?,
2790            Map { input, scalars: _ } => {
2791                f(input)?;
2792            }
2793            CallTable { func: _, exprs: _ } => (),
2794            Filter {
2795                input,
2796                predicates: _,
2797            } => {
2798                f(input)?;
2799            }
2800            Join {
2801                left,
2802                right,
2803                on: _,
2804                kind: _,
2805            } => {
2806                f(left)?;
2807                f(right)?;
2808            }
2809            Reduce {
2810                input,
2811                group_key: _,
2812                aggregates: _,
2813                expected_group_size: _,
2814            } => {
2815                f(input)?;
2816            }
2817            Distinct { input }
2818            | TopK {
2819                input,
2820                group_key: _,
2821                order_key: _,
2822                limit: _,
2823                offset: _,
2824                expected_group_size: _,
2825            }
2826            | Negate { input }
2827            | Threshold { input } => {
2828                f(input)?;
2829            }
2830            Union { base, inputs } => {
2831                f(base)?;
2832                for input in inputs {
2833                    f(input)?;
2834                }
2835            }
2836        }
2837        Ok(())
2838    }
2839}
2840
2841impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2842    fn visit_children<F>(&self, mut f: F)
2843    where
2844        F: FnMut(&HirScalarExpr),
2845    {
2846        use HirRelationExpr::*;
2847        match self {
2848            Constant { rows: _, typ: _ }
2849            | Get { id: _, typ: _ }
2850            | Let {
2851                name: _,
2852                id: _,
2853                value: _,
2854                body: _,
2855            }
2856            | LetRec {
2857                limit: _,
2858                bindings: _,
2859                body: _,
2860            }
2861            | Project {
2862                input: _,
2863                outputs: _,
2864            } => (),
2865            Map { input: _, scalars } => {
2866                for scalar in scalars {
2867                    f(scalar);
2868                }
2869            }
2870            CallTable { func: _, exprs } => {
2871                for expr in exprs {
2872                    f(expr);
2873                }
2874            }
2875            Filter {
2876                input: _,
2877                predicates,
2878            } => {
2879                for predicate in predicates {
2880                    f(predicate);
2881                }
2882            }
2883            Join {
2884                left: _,
2885                right: _,
2886                on,
2887                kind: _,
2888            } => f(on),
2889            Reduce {
2890                input: _,
2891                group_key: _,
2892                aggregates,
2893                expected_group_size: _,
2894            } => {
2895                for aggregate in aggregates {
2896                    f(aggregate.expr.as_ref());
2897                }
2898            }
2899            TopK {
2900                input: _,
2901                group_key: _,
2902                order_key: _,
2903                limit,
2904                offset,
2905                expected_group_size: _,
2906            } => {
2907                if let Some(limit) = limit {
2908                    f(limit)
2909                }
2910                f(offset)
2911            }
2912            Distinct { input: _ }
2913            | Negate { input: _ }
2914            | Threshold { input: _ }
2915            | Union { base: _, inputs: _ } => (),
2916        }
2917    }
2918
2919    fn visit_mut_children<F>(&mut self, mut f: F)
2920    where
2921        F: FnMut(&mut HirScalarExpr),
2922    {
2923        use HirRelationExpr::*;
2924        match self {
2925            Constant { rows: _, typ: _ }
2926            | Get { id: _, typ: _ }
2927            | Let {
2928                name: _,
2929                id: _,
2930                value: _,
2931                body: _,
2932            }
2933            | LetRec {
2934                limit: _,
2935                bindings: _,
2936                body: _,
2937            }
2938            | Project {
2939                input: _,
2940                outputs: _,
2941            } => (),
2942            Map { input: _, scalars } => {
2943                for scalar in scalars {
2944                    f(scalar);
2945                }
2946            }
2947            CallTable { func: _, exprs } => {
2948                for expr in exprs {
2949                    f(expr);
2950                }
2951            }
2952            Filter {
2953                input: _,
2954                predicates,
2955            } => {
2956                for predicate in predicates {
2957                    f(predicate);
2958                }
2959            }
2960            Join {
2961                left: _,
2962                right: _,
2963                on,
2964                kind: _,
2965            } => f(on),
2966            Reduce {
2967                input: _,
2968                group_key: _,
2969                aggregates,
2970                expected_group_size: _,
2971            } => {
2972                for aggregate in aggregates {
2973                    f(aggregate.expr.as_mut());
2974                }
2975            }
2976            TopK {
2977                input: _,
2978                group_key: _,
2979                order_key: _,
2980                limit,
2981                offset,
2982                expected_group_size: _,
2983            } => {
2984                if let Some(limit) = limit {
2985                    f(limit)
2986                }
2987                f(offset)
2988            }
2989            Distinct { input: _ }
2990            | Negate { input: _ }
2991            | Threshold { input: _ }
2992            | Union { base: _, inputs: _ } => (),
2993        }
2994    }
2995
2996    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2997    where
2998        F: FnMut(&HirScalarExpr) -> Result<(), E>,
2999        E: From<RecursionLimitError>,
3000    {
3001        use HirRelationExpr::*;
3002        match self {
3003            Constant { rows: _, typ: _ }
3004            | Get { id: _, typ: _ }
3005            | Let {
3006                name: _,
3007                id: _,
3008                value: _,
3009                body: _,
3010            }
3011            | LetRec {
3012                limit: _,
3013                bindings: _,
3014                body: _,
3015            }
3016            | Project {
3017                input: _,
3018                outputs: _,
3019            } => (),
3020            Map { input: _, scalars } => {
3021                for scalar in scalars {
3022                    f(scalar)?;
3023                }
3024            }
3025            CallTable { func: _, exprs } => {
3026                for expr in exprs {
3027                    f(expr)?;
3028                }
3029            }
3030            Filter {
3031                input: _,
3032                predicates,
3033            } => {
3034                for predicate in predicates {
3035                    f(predicate)?;
3036                }
3037            }
3038            Join {
3039                left: _,
3040                right: _,
3041                on,
3042                kind: _,
3043            } => f(on)?,
3044            Reduce {
3045                input: _,
3046                group_key: _,
3047                aggregates,
3048                expected_group_size: _,
3049            } => {
3050                for aggregate in aggregates {
3051                    f(aggregate.expr.as_ref())?;
3052                }
3053            }
3054            TopK {
3055                input: _,
3056                group_key: _,
3057                order_key: _,
3058                limit,
3059                offset,
3060                expected_group_size: _,
3061            } => {
3062                if let Some(limit) = limit {
3063                    f(limit)?
3064                }
3065                f(offset)?
3066            }
3067            Distinct { input: _ }
3068            | Negate { input: _ }
3069            | Threshold { input: _ }
3070            | Union { base: _, inputs: _ } => (),
3071        }
3072        Ok(())
3073    }
3074
3075    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3076    where
3077        F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3078        E: From<RecursionLimitError>,
3079    {
3080        use HirRelationExpr::*;
3081        match self {
3082            Constant { rows: _, typ: _ }
3083            | Get { id: _, typ: _ }
3084            | Let {
3085                name: _,
3086                id: _,
3087                value: _,
3088                body: _,
3089            }
3090            | LetRec {
3091                limit: _,
3092                bindings: _,
3093                body: _,
3094            }
3095            | Project {
3096                input: _,
3097                outputs: _,
3098            } => (),
3099            Map { input: _, scalars } => {
3100                for scalar in scalars {
3101                    f(scalar)?;
3102                }
3103            }
3104            CallTable { func: _, exprs } => {
3105                for expr in exprs {
3106                    f(expr)?;
3107                }
3108            }
3109            Filter {
3110                input: _,
3111                predicates,
3112            } => {
3113                for predicate in predicates {
3114                    f(predicate)?;
3115                }
3116            }
3117            Join {
3118                left: _,
3119                right: _,
3120                on,
3121                kind: _,
3122            } => f(on)?,
3123            Reduce {
3124                input: _,
3125                group_key: _,
3126                aggregates,
3127                expected_group_size: _,
3128            } => {
3129                for aggregate in aggregates {
3130                    f(aggregate.expr.as_mut())?;
3131                }
3132            }
3133            TopK {
3134                input: _,
3135                group_key: _,
3136                order_key: _,
3137                limit,
3138                offset,
3139                expected_group_size: _,
3140            } => {
3141                if let Some(limit) = limit {
3142                    f(limit)?
3143                }
3144                f(offset)?
3145            }
3146            Distinct { input: _ }
3147            | Negate { input: _ }
3148            | Threshold { input: _ }
3149            | Union { base: _, inputs: _ } => (),
3150        }
3151        Ok(())
3152    }
3153}
3154
3155impl HirScalarExpr {
3156    pub fn name(&self) -> Option<Arc<str>> {
3157        use HirScalarExpr::*;
3158        match self {
3159            Column(_, name)
3160            | Parameter(_, name)
3161            | Literal(_, _, name)
3162            | CallUnmaterializable(_, name)
3163            | CallUnary { name, .. }
3164            | CallBinary { name, .. }
3165            | CallVariadic { name, .. }
3166            | If { name, .. }
3167            | Exists(_, name)
3168            | Select(_, name)
3169            | Windowing(_, name) => name.0.clone(),
3170        }
3171    }
3172
3173    /// Replaces any parameter references in the expression with the
3174    /// corresponding datum in `params`.
3175    pub fn bind_parameters(
3176        &mut self,
3177        scx: &StatementContext,
3178        lifetime: QueryLifetime,
3179        params: &Params,
3180    ) -> Result<(), PlanError> {
3181        #[allow(deprecated)]
3182        self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3183            if let HirScalarExpr::Parameter(n, name) = e {
3184                let datum = match params.datums.iter().nth(*n - 1) {
3185                    None => return Err(PlanError::UnknownParameter(*n)),
3186                    Some(datum) => datum,
3187                };
3188                let scalar_type = &params.execute_types[*n - 1];
3189                let row = Row::pack([datum]);
3190                let column_type = scalar_type.clone().nullable(datum.is_null());
3191
3192                let name = if let Some(name) = &name.0 {
3193                    Some(Arc::clone(name))
3194                } else {
3195                    Some(Arc::from(format!("${n}")))
3196                };
3197
3198                let qcx = QueryContext::root(scx, lifetime);
3199                let ecx = execute_expr_context(&qcx);
3200
3201                *e = plan_cast(
3202                    &ecx,
3203                    *EXECUTE_CAST_CONTEXT,
3204                    HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3205                    &params.expected_types[*n - 1],
3206                )
3207                .expect("checked in plan_params");
3208            }
3209            Ok(())
3210        })
3211    }
3212
3213    /// Like [`HirScalarExpr::bind_parameters`], except that parameters are
3214    /// replaced with the corresponding expression fragment from `params` rather
3215    /// than a datum.
3216    ///
3217    /// Specifically, the parameter `$1` will be replaced with `params[0]`, the
3218    /// parameter `$2` will be replaced with `params[1]`, and so on. Parameters
3219    /// in `self` that refer to invalid indices of `params` will cause a panic.
3220    ///
3221    /// Column references in parameters will be corrected to account for the
3222    /// depth at which they are spliced.
3223    pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3224        #[allow(deprecated)]
3225        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3226                                                        e: &mut HirScalarExpr|
3227         -> Result<(), ()> {
3228            if let HirScalarExpr::Parameter(i, _name) = e {
3229                *e = params[*i - 1].clone();
3230                // Correct any column references in the parameter expression for
3231                // its new depth.
3232                e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3233                    if col.level >= d {
3234                        col.level += depth
3235                    }
3236                });
3237            }
3238            Ok(())
3239        });
3240    }
3241
3242    /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call.
3243    pub fn contains_temporal(&self) -> bool {
3244        let mut contains = false;
3245        #[allow(deprecated)]
3246        self.visit_post_nolimit(&mut |e| {
3247            if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3248                contains = true;
3249            }
3250        });
3251        contains
3252    }
3253
3254    /// Whether the expression contains any [`UnmaterializableFunc`] call.
3255    pub fn contains_unmaterializable(&self) -> bool {
3256        let mut contains = false;
3257        #[allow(deprecated)]
3258        self.visit_post_nolimit(&mut |e| {
3259            if let Self::CallUnmaterializable(_, _) = e {
3260                contains = true;
3261            }
3262        });
3263        contains
3264    }
3265
3266    /// Whether the expression contains any [`UnmaterializableFunc`] call other than
3267    /// [`UnmaterializableFunc::MzNow`].
3268    pub fn contains_unmaterializable_except_temporal(&self) -> bool {
3269        let mut contains = false;
3270        #[allow(deprecated)]
3271        self.visit_post_nolimit(&mut |e| {
3272            if let Self::CallUnmaterializable(f, _) = e {
3273                if *f != UnmaterializableFunc::MzNow {
3274                    contains = true;
3275                }
3276            }
3277        });
3278        contains
3279    }
3280
3281    /// Constructs an unnamed column reference in the current scope.
3282    /// Use [`HirScalarExpr::named_column`] when a name is known.
3283    /// Use [`HirScalarExpr::unnamed_column`] for a `ColumnRef`.
3284    pub fn column(index: usize) -> HirScalarExpr {
3285        HirScalarExpr::Column(
3286            ColumnRef {
3287                level: 0,
3288                column: index,
3289            },
3290            TreatAsEqual(None),
3291        )
3292    }
3293
3294    /// Constructs an unnamed column reference.
3295    pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3296        HirScalarExpr::Column(cr, TreatAsEqual(None))
3297    }
3298
3299    /// Constructs a named column reference.
3300    /// Names are interned by a `NameManager`.
3301    pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3302        HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3303    }
3304
3305    pub fn parameter(n: usize) -> HirScalarExpr {
3306        HirScalarExpr::Parameter(n, TreatAsEqual(None))
3307    }
3308
3309    pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3310        let col_type = scalar_type.nullable(datum.is_null());
3311        soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3312        let row = Row::pack([datum]);
3313        HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3314    }
3315
3316    pub fn literal_true() -> HirScalarExpr {
3317        HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3318    }
3319
3320    pub fn literal_false() -> HirScalarExpr {
3321        HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3322    }
3323
3324    pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3325        HirScalarExpr::literal(Datum::Null, scalar_type)
3326    }
3327
3328    pub fn literal_1d_array(
3329        datums: Vec<Datum>,
3330        element_scalar_type: SqlScalarType,
3331    ) -> Result<HirScalarExpr, PlanError> {
3332        let scalar_type = match element_scalar_type {
3333            SqlScalarType::Array(_) => {
3334                sql_bail!("cannot build array from array type");
3335            }
3336            typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3337        };
3338
3339        let mut row = Row::default();
3340        row.packer()
3341            .try_push_array(
3342                &[ArrayDimension {
3343                    lower_bound: 1,
3344                    length: datums.len(),
3345                }],
3346                datums,
3347            )
3348            .expect("array constructed to be valid");
3349
3350        Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3351    }
3352
3353    pub fn as_literal(&self) -> Option<Datum<'_>> {
3354        if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3355            Some(row.unpack_first())
3356        } else {
3357            None
3358        }
3359    }
3360
3361    pub fn is_literal_true(&self) -> bool {
3362        Some(Datum::True) == self.as_literal()
3363    }
3364
3365    pub fn is_literal_false(&self) -> bool {
3366        Some(Datum::False) == self.as_literal()
3367    }
3368
3369    pub fn is_literal_null(&self) -> bool {
3370        Some(Datum::Null) == self.as_literal()
3371    }
3372
3373    /// Return true iff `self` consists only of literals, materializable function calls, and
3374    /// if-else statements.
3375    pub fn is_constant(&self) -> bool {
3376        let mut worklist = vec![self];
3377        while let Some(expr) = worklist.pop() {
3378            match expr {
3379                Self::Literal(..) => {
3380                    // leaf node, do nothing
3381                }
3382                Self::CallUnary { expr, .. } => {
3383                    worklist.push(expr);
3384                }
3385                Self::CallBinary {
3386                    func: _,
3387                    expr1,
3388                    expr2,
3389                    name: _,
3390                } => {
3391                    worklist.push(expr1);
3392                    worklist.push(expr2);
3393                }
3394                Self::CallVariadic {
3395                    func: _,
3396                    exprs,
3397                    name: _,
3398                } => {
3399                    worklist.extend(exprs.iter());
3400                }
3401                // (CallUnmaterializable is not allowed)
3402                Self::If {
3403                    cond,
3404                    then,
3405                    els,
3406                    name: _,
3407                } => {
3408                    worklist.push(cond);
3409                    worklist.push(then);
3410                    worklist.push(els);
3411                }
3412                _ => {
3413                    return false; // Any other node makes `self` non-constant.
3414                }
3415            }
3416        }
3417        true
3418    }
3419
3420    pub fn call_unary(self, func: UnaryFunc) -> Self {
3421        HirScalarExpr::CallUnary {
3422            func,
3423            expr: Box::new(self),
3424            name: NameMetadata::default(),
3425        }
3426    }
3427
3428    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3429        HirScalarExpr::CallBinary {
3430            func: func.into(),
3431            expr1: Box::new(self),
3432            expr2: Box::new(other),
3433            name: NameMetadata::default(),
3434        }
3435    }
3436
3437    pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3438        HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3439    }
3440
3441    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
3442        HirScalarExpr::CallVariadic {
3443            func: func.into(),
3444            exprs,
3445            name: NameMetadata::default(),
3446        }
3447    }
3448
3449    pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3450        HirScalarExpr::If {
3451            cond: Box::new(cond),
3452            then: Box::new(then),
3453            els: Box::new(els),
3454            name: NameMetadata::default(),
3455        }
3456    }
3457
3458    pub fn windowing(expr: WindowExpr) -> Self {
3459        HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3460    }
3461
3462    pub fn or(self, other: Self) -> Self {
3463        HirScalarExpr::call_variadic(Or, vec![self, other])
3464    }
3465
3466    pub fn and(self, other: Self) -> Self {
3467        HirScalarExpr::call_variadic(And, vec![self, other])
3468    }
3469
3470    pub fn not(self) -> Self {
3471        self.call_unary(UnaryFunc::Not(func::Not))
3472    }
3473
3474    pub fn call_is_null(self) -> Self {
3475        self.call_unary(UnaryFunc::IsNull(func::IsNull))
3476    }
3477
3478    /// Calls AND with the given arguments. Simplifies if 0 or 1 args.
3479    pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3480        match args.len() {
3481            0 => HirScalarExpr::literal_true(), // Same as unit_of_and_or, but that's MirScalarExpr
3482            1 => args.swap_remove(0),
3483            _ => HirScalarExpr::call_variadic(And, args),
3484        }
3485    }
3486
3487    /// Calls OR with the given arguments. Simplifies if 0 or 1 args.
3488    pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3489        match args.len() {
3490            0 => HirScalarExpr::literal_false(), // Same as unit_of_and_or, but that's MirScalarExpr
3491            1 => args.swap_remove(0),
3492            _ => HirScalarExpr::call_variadic(Or, args),
3493        }
3494    }
3495
3496    pub fn take(&mut self) -> Self {
3497        mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3498    }
3499
3500    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3501    /// Visits the column references in this scalar expression.
3502    ///
3503    /// The `depth` argument should indicate the subquery nesting depth of the expression,
3504    /// which will be incremented with each subquery entered and presented to the supplied
3505    /// function `f`.
3506    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3507    where
3508        F: FnMut(usize, &ColumnRef),
3509    {
3510        #[allow(deprecated)]
3511        let _ = self.visit_recursively(depth, &mut |depth: usize,
3512                                                    e: &HirScalarExpr|
3513         -> Result<(), ()> {
3514            if let HirScalarExpr::Column(col, _name) = e {
3515                f(depth, col)
3516            }
3517            Ok(())
3518        });
3519    }
3520
3521    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3522    /// Like `visit_columns`, but permits mutating the column references.
3523    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3524    where
3525        F: FnMut(usize, &mut ColumnRef),
3526    {
3527        #[allow(deprecated)]
3528        let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3529                                                        e: &mut HirScalarExpr|
3530         -> Result<(), ()> {
3531            if let HirScalarExpr::Column(col, _name) = e {
3532                f(depth, col)
3533            }
3534            Ok(())
3535        });
3536    }
3537
3538    /// Visits those column references in this scalar expression that refer to the root
3539    /// level. These include column references that are at the root level, as well as column
3540    /// references that are at a deeper subquery nesting depth, but refer back to the root level.
3541    /// (Note that even if `self` is embedded inside a larger expression, we consider the
3542    /// "root level" to be `self`'s level.)
3543    pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3544    where
3545        F: FnMut(usize),
3546    {
3547        #[allow(deprecated)]
3548        let _ = self.visit_recursively(0, &mut |depth: usize,
3549                                                e: &HirScalarExpr|
3550         -> Result<(), ()> {
3551            if let HirScalarExpr::Column(col, _name) = e {
3552                if col.level == depth {
3553                    f(col.column)
3554                }
3555            }
3556            Ok(())
3557        });
3558    }
3559
3560    /// Like `visit_columns_referring_to_root_level`, but permits mutating the column references.
3561    pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3562    where
3563        F: FnMut(&mut usize),
3564    {
3565        #[allow(deprecated)]
3566        let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3567                                                    e: &mut HirScalarExpr|
3568         -> Result<(), ()> {
3569            if let HirScalarExpr::Column(col, _name) = e {
3570                if col.level == depth {
3571                    f(&mut col.column)
3572                }
3573            }
3574            Ok(())
3575        });
3576    }
3577
3578    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3579    /// Like `visit` but it enters the subqueries visiting the scalar expressions contained
3580    /// in them. It takes the current depth of the expression and increases it when
3581    /// entering a subquery.
3582    pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3583    where
3584        F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3585    {
3586        match self {
3587            HirScalarExpr::Literal(..)
3588            | HirScalarExpr::Parameter(..)
3589            | HirScalarExpr::CallUnmaterializable(..)
3590            | HirScalarExpr::Column(..) => (),
3591            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3592            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3593                expr1.visit_recursively(depth, f)?;
3594                expr2.visit_recursively(depth, f)?;
3595            }
3596            HirScalarExpr::CallVariadic { exprs, .. } => {
3597                for expr in exprs {
3598                    expr.visit_recursively(depth, f)?;
3599                }
3600            }
3601            HirScalarExpr::If {
3602                cond,
3603                then,
3604                els,
3605                name: _,
3606            } => {
3607                cond.visit_recursively(depth, f)?;
3608                then.visit_recursively(depth, f)?;
3609                els.visit_recursively(depth, f)?;
3610            }
3611            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3612                #[allow(deprecated)]
3613                expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3614                    e.visit_recursively(depth, f)
3615                })?;
3616            }
3617            HirScalarExpr::Windowing(expr, _name) => {
3618                expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3619            }
3620        }
3621        f(depth, self)
3622    }
3623
3624    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3625    /// Like `visit_recursively`, but permits mutating the scalar expressions.
3626    pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3627    where
3628        F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3629    {
3630        match self {
3631            HirScalarExpr::Literal(..)
3632            | HirScalarExpr::Parameter(..)
3633            | HirScalarExpr::CallUnmaterializable(..)
3634            | HirScalarExpr::Column(..) => (),
3635            HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3636            HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3637                expr1.visit_recursively_mut(depth, f)?;
3638                expr2.visit_recursively_mut(depth, f)?;
3639            }
3640            HirScalarExpr::CallVariadic { exprs, .. } => {
3641                for expr in exprs {
3642                    expr.visit_recursively_mut(depth, f)?;
3643                }
3644            }
3645            HirScalarExpr::If {
3646                cond,
3647                then,
3648                els,
3649                name: _,
3650            } => {
3651                cond.visit_recursively_mut(depth, f)?;
3652                then.visit_recursively_mut(depth, f)?;
3653                els.visit_recursively_mut(depth, f)?;
3654            }
3655            HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3656                #[allow(deprecated)]
3657                expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3658                    e.visit_recursively_mut(depth, f)
3659                })?;
3660            }
3661            HirScalarExpr::Windowing(expr, _name) => {
3662                expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3663            }
3664        }
3665        f(depth, self)
3666    }
3667
3668    /// Attempts to simplify self into a literal.
3669    ///
3670    /// Returns None if self is not constant and therefore can't be simplified to a literal, or if
3671    /// an evaluation error occurs during simplification, or if self contains
3672    /// - a subquery
3673    /// - a column reference to an outer level
3674    /// - a parameter
3675    /// - a window function call
3676    fn simplify_to_literal(self) -> Option<Row> {
3677        let mut expr = self
3678            .lower_uncorrelated(crate::plan::lowering::Config::default())
3679            .ok()?;
3680        // Using MIR evaluation with repr types is fine here: the
3681        // result is an untyped Row, so any intermediate type
3682        // canonicalization is discarded.
3683        expr.reduce(&[]);
3684        match expr {
3685            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3686            _ => None,
3687        }
3688    }
3689
3690    /// Simplifies self into a literal. If this is not possible (e.g., because self is not constant
3691    /// or an evaluation error occurs during simplification), it returns
3692    /// [`PlanError::ConstantExpressionSimplificationFailed`].
3693    ///
3694    /// The returned error is an _internal_ error if the expression contains
3695    /// - a subquery
3696    /// - a column reference to an outer level
3697    /// - a parameter
3698    /// - a window function call
3699    ///
3700    /// TODO: use this everywhere instead of `simplify_to_literal`, so that we don't hide the error
3701    /// msg.
3702    fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3703        let mut expr = self
3704            .lower_uncorrelated(crate::plan::lowering::Config::default())
3705            .map_err(|err| {
3706                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3707            })?;
3708        // Using MIR evaluation with repr types is fine here: the
3709        // result is an untyped Row, so any intermediate type
3710        // canonicalization is discarded.
3711        expr.reduce(&[]);
3712        match expr {
3713            mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3714            mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3715                PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3716            ),
3717            _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3718                "Not a constant".to_string(),
3719            )),
3720        }
3721    }
3722
3723    /// Attempts to simplify this expression to a literal 64-bit integer.
3724    ///
3725    /// Returns `None` if this expression cannot be simplified, e.g. because it
3726    /// contains non-literal values.
3727    ///
3728    /// # Panics
3729    ///
3730    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3731    pub fn into_literal_int64(self) -> Option<i64> {
3732        self.simplify_to_literal().and_then(|row| {
3733            let datum = row.unpack_first();
3734            if datum.is_null() {
3735                None
3736            } else {
3737                Some(datum.unwrap_int64())
3738            }
3739        })
3740    }
3741
3742    /// Attempts to simplify this expression to a literal string.
3743    ///
3744    /// Returns `None` if this expression cannot be simplified, e.g. because it
3745    /// contains non-literal values.
3746    ///
3747    /// # Panics
3748    ///
3749    /// Panics if this expression does not have type [`SqlScalarType::String`].
3750    pub fn into_literal_string(self) -> Option<String> {
3751        self.simplify_to_literal().and_then(|row| {
3752            let datum = row.unpack_first();
3753            if datum.is_null() {
3754                None
3755            } else {
3756                Some(datum.unwrap_str().to_owned())
3757            }
3758        })
3759    }
3760
3761    /// Attempts to simplify this expression to a literal MzTimestamp.
3762    ///
3763    /// Returns `None` if the expression simplifies to `null` or if the expression cannot be
3764    /// simplified, e.g. because it contains non-literal values or a cast fails.
3765    ///
3766    /// TODO: Make this (and the other similar fns above) return Result, so that we can show the
3767    /// error when it fails. (E.g., there can be non-trivial cast errors.)
3768    /// See `try_into_literal_int64` as an example.
3769    ///
3770    /// # Panics
3771    ///
3772    /// Panics if this expression does not have type [`SqlScalarType::MzTimestamp`].
3773    pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3774        self.simplify_to_literal().and_then(|row| {
3775            let datum = row.unpack_first();
3776            if datum.is_null() {
3777                None
3778            } else {
3779                Some(datum.unwrap_mz_timestamp())
3780            }
3781        })
3782    }
3783
3784    /// Attempts to simplify this expression of [`SqlScalarType::Int64`] to a literal Int64 and
3785    /// returns it as an i64.
3786    ///
3787    /// Returns `PlanError::ConstantExpressionSimplificationFailed` if
3788    /// - it's not a constant expression (as determined by `is_constant`)
3789    /// - evaluates to null
3790    /// - an EvalError occurs during evaluation (e.g., a cast fails)
3791    ///
3792    /// # Panics
3793    ///
3794    /// Panics if this expression does not have type [`SqlScalarType::Int64`].
3795    pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3796        // TODO: add the `is_constant` check also to all the other into_literal_... (by adding it to
3797        // `simplify_to_literal`), but those should be just soft_asserts at first that it doesn't
3798        // actually happen that it's weaker than `reduce`, and then add them for real after 1 week.
3799        // (Without the is_constant check, lower_uncorrelated's preconditions spill out to be
3800        // preconditions also of all the other into_literal_... functions.)
3801        if !self.is_constant() {
3802            return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3803                "Expected a constant expression, got {}",
3804                self
3805            )));
3806        }
3807        self.clone()
3808            .simplify_to_literal_with_result()
3809            .and_then(|row| {
3810                let datum = row.unpack_first();
3811                if datum.is_null() {
3812                    Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3813                        "Expected an expression that evaluates to a non-null value, got {}",
3814                        self
3815                    )))
3816                } else {
3817                    Ok(datum.unwrap_int64())
3818                }
3819            })
3820    }
3821
3822    pub fn contains_parameters(&self) -> bool {
3823        let mut contains_parameters = false;
3824        #[allow(deprecated)]
3825        let _ = self.visit_recursively(0, &mut |_depth: usize,
3826                                                expr: &HirScalarExpr|
3827         -> Result<(), ()> {
3828            if let HirScalarExpr::Parameter(..) = expr {
3829                contains_parameters = true;
3830            }
3831            Ok(())
3832        });
3833        contains_parameters
3834    }
3835}
3836
3837impl VisitChildren<Self> for HirScalarExpr {
3838    fn visit_children<F>(&self, mut f: F)
3839    where
3840        F: FnMut(&Self),
3841    {
3842        use HirScalarExpr::*;
3843        match self {
3844            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3845            CallUnary { expr, .. } => f(expr),
3846            CallBinary { expr1, expr2, .. } => {
3847                f(expr1);
3848                f(expr2);
3849            }
3850            CallVariadic { exprs, .. } => {
3851                for expr in exprs {
3852                    f(expr);
3853                }
3854            }
3855            If {
3856                cond,
3857                then,
3858                els,
3859                name: _,
3860            } => {
3861                f(cond);
3862                f(then);
3863                f(els);
3864            }
3865            Exists(..) | Select(..) => (),
3866            Windowing(expr, _name) => expr.visit_children(f),
3867        }
3868    }
3869
3870    fn visit_mut_children<F>(&mut self, mut f: F)
3871    where
3872        F: FnMut(&mut Self),
3873    {
3874        use HirScalarExpr::*;
3875        match self {
3876            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3877            CallUnary { expr, .. } => f(expr),
3878            CallBinary { expr1, expr2, .. } => {
3879                f(expr1);
3880                f(expr2);
3881            }
3882            CallVariadic { exprs, .. } => {
3883                for expr in exprs {
3884                    f(expr);
3885                }
3886            }
3887            If {
3888                cond,
3889                then,
3890                els,
3891                name: _,
3892            } => {
3893                f(cond);
3894                f(then);
3895                f(els);
3896            }
3897            Exists(..) | Select(..) => (),
3898            Windowing(expr, _name) => expr.visit_mut_children(f),
3899        }
3900    }
3901
3902    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3903    where
3904        F: FnMut(&Self) -> Result<(), E>,
3905        E: From<RecursionLimitError>,
3906    {
3907        use HirScalarExpr::*;
3908        match self {
3909            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3910            CallUnary { expr, .. } => f(expr)?,
3911            CallBinary { expr1, expr2, .. } => {
3912                f(expr1)?;
3913                f(expr2)?;
3914            }
3915            CallVariadic { exprs, .. } => {
3916                for expr in exprs {
3917                    f(expr)?;
3918                }
3919            }
3920            If {
3921                cond,
3922                then,
3923                els,
3924                name: _,
3925            } => {
3926                f(cond)?;
3927                f(then)?;
3928                f(els)?;
3929            }
3930            Exists(..) | Select(..) => (),
3931            Windowing(expr, _name) => expr.try_visit_children(f)?,
3932        }
3933        Ok(())
3934    }
3935
3936    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3937    where
3938        F: FnMut(&mut Self) -> Result<(), E>,
3939        E: From<RecursionLimitError>,
3940    {
3941        use HirScalarExpr::*;
3942        match self {
3943            Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3944            CallUnary { expr, .. } => f(expr)?,
3945            CallBinary { expr1, expr2, .. } => {
3946                f(expr1)?;
3947                f(expr2)?;
3948            }
3949            CallVariadic { exprs, .. } => {
3950                for expr in exprs {
3951                    f(expr)?;
3952                }
3953            }
3954            If {
3955                cond,
3956                then,
3957                els,
3958                name: _,
3959            } => {
3960                f(cond)?;
3961                f(then)?;
3962                f(els)?;
3963            }
3964            Exists(..) | Select(..) => (),
3965            Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3966        }
3967        Ok(())
3968    }
3969}
3970
3971impl AbstractExpr for HirScalarExpr {
3972    type Type = SqlColumnType;
3973
3974    fn typ(
3975        &self,
3976        outers: &[SqlRelationType],
3977        inner: &SqlRelationType,
3978        params: &BTreeMap<usize, SqlScalarType>,
3979    ) -> Self::Type {
3980        stack::maybe_grow(|| match self {
3981            HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3982                if *level == 0 {
3983                    inner.column_types[*column].clone()
3984                } else {
3985                    outers[*level - 1].column_types[*column].clone()
3986                }
3987            }
3988            HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3989            HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3990            HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(),
3991            HirScalarExpr::CallUnary {
3992                expr,
3993                func,
3994                name: _,
3995            } => func.output_sql_type(expr.typ(outers, inner, params)),
3996            HirScalarExpr::CallBinary {
3997                expr1,
3998                expr2,
3999                func,
4000                name: _,
4001            } => func.output_sql_type(&[
4002                expr1.typ(outers, inner, params),
4003                expr2.typ(outers, inner, params),
4004            ]),
4005            HirScalarExpr::CallVariadic {
4006                exprs,
4007                func,
4008                name: _,
4009            } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
4010            HirScalarExpr::If {
4011                cond: _,
4012                then,
4013                els,
4014                name: _,
4015            } => {
4016                let then_type = then.typ(outers, inner, params);
4017                let else_type = els.typ(outers, inner, params);
4018                then_type.sql_union(&else_type).unwrap() // HIR deliberately not using `union`
4019            }
4020            HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
4021            HirScalarExpr::Select(expr, _name) => {
4022                let mut outers = outers.to_vec();
4023                outers.insert(0, inner.clone());
4024                expr.typ(&outers, params)
4025                    .column_types
4026                    .into_element()
4027                    .nullable(true)
4028            }
4029            HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
4030        })
4031    }
4032}
4033
4034impl AggregateExpr {
4035    pub fn typ(
4036        &self,
4037        outers: &[SqlRelationType],
4038        inner: &SqlRelationType,
4039        params: &BTreeMap<usize, SqlScalarType>,
4040    ) -> SqlColumnType {
4041        self.func
4042            .output_sql_type(self.expr.typ(outers, inner, params))
4043    }
4044
4045    /// Returns whether the expression is COUNT(*) or not.  Note that
4046    /// when we define the count builtin in sql::func, we convert
4047    /// COUNT(*) to COUNT(true), making it indistinguishable from
4048    /// literal COUNT(true), but we prefer to consider this as the
4049    /// former.
4050    ///
4051    /// (MIR has the same `is_count_asterisk`.)
4052    pub fn is_count_asterisk(&self) -> bool {
4053        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
4054    }
4055}