Skip to main content

mz_expr/
scalar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::ops::BitOrAssign;
12use std::sync::Arc;
13use std::{fmt, mem};
14
15use itertools::Itertools;
16use mz_lowertest::MzReflect;
17use mz_ore::cast::CastFrom;
18use mz_ore::collections::CollectionExt;
19use mz_ore::iter::IteratorExt;
20use mz_ore::stack::RecursionLimitError;
21use mz_ore::str::StrExt;
22use mz_ore::treat_as_equal::TreatAsEqual;
23use mz_ore::vec::swap_remove_multiple;
24use mz_pgrepr::TypeFromOidError;
25use mz_pgtz::timezone::TimezoneSpec;
26use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
27use mz_repr::adt::array::InvalidArrayError;
28use mz_repr::adt::date::DateError;
29use mz_repr::adt::datetime::DateTimeUnits;
30use mz_repr::adt::range::InvalidRangeError;
31use mz_repr::adt::regex::{Regex, RegexCompilationError};
32use mz_repr::adt::timestamp::TimestampError;
33use mz_repr::strconv::{ParseError, ParseHexError};
34use mz_repr::{Datum, ReprColumnType, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest::prelude::*;
36use proptest_derive::Arbitrary;
37use serde::{Deserialize, Serialize};
38
39use crate::scalar::func::format::DateTimeFormat;
40use crate::scalar::func::{
41    BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, parse_timezone,
42    regexp_replace_parse_flags,
43};
44use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims;
45use crate::visit::{Visit, VisitChildren};
46
47pub mod func;
48pub mod like_pattern;
49
50include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs"));
51
/// A scalar expression tree.
///
/// Leaves are column references, literals, and unmaterializable function
/// calls; interior nodes are unary, binary, and variadic function calls and
/// the conditional `If`.
///
/// NOTE(review): the comparison, hashing, and serialization impls are
/// derived, so they depend on the declaration order of the variants and
/// fields below — reorder with care.
#[derive(
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    MzReflect
)]
pub enum MirScalarExpr {
    /// A column of the input row, identified by its index, together with an
    /// optional name for that column.
    ///
    /// NOTE(review): the name is wrapped in `TreatAsEqual`, presumably so that
    /// it does not take part in the derived comparisons — confirm against
    /// `mz_ore::treat_as_equal`.
    Column(usize, TreatAsEqual<Option<Arc<str>>>),
    /// A literal value, or a literal error, along with its column type.
    /// (The value is stored as a row, because we can't own a Datum.)
    Literal(Result<Row, EvalError>, SqlColumnType),
    /// A call to an unmaterializable function.
    ///
    /// These functions cannot be evaluated by `MirScalarExpr::eval`. They must
    /// be transformed away by a higher layer.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function call that takes one expression as an argument.
    CallUnary {
        /// The function to apply.
        func: UnaryFunc,
        /// The single argument.
        expr: Box<MirScalarExpr>,
    },
    /// A function call that takes two expressions as arguments.
    CallBinary {
        /// The function to apply.
        func: BinaryFunc,
        /// The first argument.
        expr1: Box<MirScalarExpr>,
        /// The second argument.
        expr2: Box<MirScalarExpr>,
    },
    /// A function call that takes an arbitrary number of arguments.
    CallVariadic {
        /// The function to apply.
        func: VariadicFunc,
        /// The arguments, in order.
        exprs: Vec<MirScalarExpr>,
    },
    /// Conditionally evaluated expressions.
    ///
    /// It is important that `then` and `els` only be evaluated if
    /// `cond` is true or not, respectively. This is the only way
    /// users can guard execution (other logical operators do not
    /// short-circuit) and we need to preserve that.
    If {
        /// The condition that selects the branch.
        cond: Box<MirScalarExpr>,
        /// The branch for when `cond` is true.
        then: Box<MirScalarExpr>,
        /// The branch for when `cond` is not true.
        els: Box<MirScalarExpr>,
    },
}
102
103// We need a custom Debug because we don't want to show `None` for name information.
104// Sadly, the `derivative` crate doesn't support this use case.
105impl std::fmt::Debug for MirScalarExpr {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match self {
108            MirScalarExpr::Column(i, TreatAsEqual(Some(name))) => {
109                write!(f, "Column({i}, {name:?})")
110            }
111            MirScalarExpr::Column(i, TreatAsEqual(None)) => write!(f, "Column({i})"),
112            MirScalarExpr::Literal(lit, typ) => write!(f, "Literal({lit:?}, {typ:?})"),
113            MirScalarExpr::CallUnmaterializable(func) => {
114                write!(f, "CallUnmaterializable({func:?})")
115            }
116            MirScalarExpr::CallUnary { func, expr } => {
117                write!(f, "CallUnary({func:?}, {expr:?})")
118            }
119            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
120                write!(f, "CallBinary({func:?}, {expr1:?}, {expr2:?})")
121            }
122            MirScalarExpr::CallVariadic { func, exprs } => {
123                write!(f, "CallVariadic({func:?}, {exprs:?})")
124            }
125            MirScalarExpr::If { cond, then, els } => {
126                write!(f, "If({cond:?}, {then:?}, {els:?})")
127            }
128        }
129    }
130}
131
132impl MirScalarExpr {
133    pub fn columns(is: &[usize]) -> Vec<MirScalarExpr> {
134        is.iter().map(|i| MirScalarExpr::column(*i)).collect()
135    }
136
137    pub fn column(column: usize) -> Self {
138        MirScalarExpr::Column(column, TreatAsEqual(None))
139    }
140
141    pub fn named_column(column: usize, name: Arc<str>) -> Self {
142        MirScalarExpr::Column(column, TreatAsEqual(Some(name)))
143    }
144
145    pub fn literal(res: Result<Datum, EvalError>, typ: SqlScalarType) -> Self {
146        let typ = typ.nullable(matches!(res, Ok(Datum::Null)));
147        let row = res.map(|datum| Row::pack_slice(&[datum]));
148        MirScalarExpr::Literal(row, typ)
149    }
150
151    pub fn literal_ok(datum: Datum, typ: SqlScalarType) -> Self {
152        MirScalarExpr::literal(Ok(datum), typ)
153    }
154
155    pub fn literal_null(typ: SqlScalarType) -> Self {
156        MirScalarExpr::literal_ok(Datum::Null, typ)
157    }
158
159    pub fn literal_false() -> Self {
160        MirScalarExpr::literal_ok(Datum::False, SqlScalarType::Bool)
161    }
162
163    pub fn literal_true() -> Self {
164        MirScalarExpr::literal_ok(Datum::True, SqlScalarType::Bool)
165    }
166
167    pub fn call_unary<U: Into<UnaryFunc>>(self, func: U) -> Self {
168        MirScalarExpr::CallUnary {
169            func: func.into(),
170            expr: Box::new(self),
171        }
172    }
173
174    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
175        MirScalarExpr::CallBinary {
176            func: func.into(),
177            expr1: Box::new(self),
178            expr2: Box::new(other),
179        }
180    }
181
182    pub fn if_then_else(self, t: Self, f: Self) -> Self {
183        MirScalarExpr::If {
184            cond: Box::new(self),
185            then: Box::new(t),
186            els: Box::new(f),
187        }
188    }
189
190    pub fn or(self, other: Self) -> Self {
191        MirScalarExpr::CallVariadic {
192            func: VariadicFunc::Or,
193            exprs: vec![self, other],
194        }
195    }
196
197    pub fn and(self, other: Self) -> Self {
198        MirScalarExpr::CallVariadic {
199            func: VariadicFunc::And,
200            exprs: vec![self, other],
201        }
202    }
203
204    pub fn not(self) -> Self {
205        self.call_unary(UnaryFunc::Not(func::Not))
206    }
207
208    pub fn call_is_null(self) -> Self {
209        self.call_unary(UnaryFunc::IsNull(func::IsNull))
210    }
211
212    /// Match AND or OR on self and get the args. If no match, then interpret self as if it were
213    /// wrapped in a 1-arg AND/OR.
214    pub fn and_or_args(&self, func_to_match: VariadicFunc) -> Vec<MirScalarExpr> {
215        assert!(func_to_match == VariadicFunc::Or || func_to_match == VariadicFunc::And);
216        match self {
217            MirScalarExpr::CallVariadic { func, exprs } if *func == func_to_match => exprs.clone(),
218            _ => vec![self.clone()],
219        }
220    }
221
222    /// Try to match a literal equality involving the given expression on one side.
223    /// Return the (non-null) literal and a bool that indicates whether an inversion was needed.
224    ///
225    /// More specifically:
226    /// If `self` is an equality with a `null` literal on any side, then the match fails!
227    /// Otherwise: for a given `expr`, if `self` is `<expr> = <literal>` or `<literal> = <expr>`
228    /// then return `Some((<literal>, false))`. In addition to just trying to match `<expr>` as it
229    /// is, we also try to remove an invertible function call (such as a cast). If the match
230    /// succeeds with the inversion, then return `Some((<inverted-literal>, true))`. For more
231    /// details on the inversion, see `invert_casts_on_expr_eq_literal_inner`.
232    pub fn expr_eq_literal(&self, expr: &MirScalarExpr) -> Option<(Row, bool)> {
233        if let MirScalarExpr::CallBinary {
234            func: BinaryFunc::Eq(_),
235            expr1,
236            expr2,
237        } = self
238        {
239            if expr1.is_literal_null() || expr2.is_literal_null() {
240                return None;
241            }
242            if let Some(Ok(lit)) = expr1.as_literal_owned() {
243                return Self::expr_eq_literal_inner(expr, lit, expr1, expr2);
244            }
245            if let Some(Ok(lit)) = expr2.as_literal_owned() {
246                return Self::expr_eq_literal_inner(expr, lit, expr2, expr1);
247            }
248        }
249        None
250    }
251
252    fn expr_eq_literal_inner(
253        expr_to_match: &MirScalarExpr,
254        literal: Row,
255        literal_expr: &MirScalarExpr,
256        other_side: &MirScalarExpr,
257    ) -> Option<(Row, bool)> {
258        if other_side == expr_to_match {
259            return Some((literal, false));
260        } else {
261            // expr didn't exactly match. See if we can match it by inverse-casting.
262            let (cast_removed, inv_cast_lit) =
263                Self::invert_casts_on_expr_eq_literal_inner(other_side, literal_expr);
264            if &cast_removed == expr_to_match {
265                if let Some(Ok(inv_cast_lit_row)) = inv_cast_lit.as_literal_owned() {
266                    return Some((inv_cast_lit_row, true));
267                }
268            }
269        }
270        None
271    }
272
273    /// If `self` is `<expr> = <literal>` or `<literal> = <expr>` then
274    /// return `<expr>`. It also tries to remove a cast (or other invertible function call) from
275    /// `<expr>` before returning it, see `invert_casts_on_expr_eq_literal_inner`.
276    pub fn any_expr_eq_literal(&self) -> Option<MirScalarExpr> {
277        if let MirScalarExpr::CallBinary {
278            func: BinaryFunc::Eq(_),
279            expr1,
280            expr2,
281        } = self
282        {
283            if expr1.is_literal() {
284                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
285                return Some(expr);
286            }
287            if expr2.is_literal() {
288                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
289                return Some(expr);
290            }
291        }
292        None
293    }
294
295    /// If the given `MirScalarExpr` is a literal equality where one side is an invertible function
296    /// call, then calls the inverse function on both sides of the equality and returns the modified
297    /// version of the given `MirScalarExpr`. Otherwise, it returns the original expression.
298    /// For more details, see `invert_casts_on_expr_eq_literal_inner`.
299    pub fn invert_casts_on_expr_eq_literal(&self) -> MirScalarExpr {
300        if let MirScalarExpr::CallBinary {
301            func: BinaryFunc::Eq(_),
302            expr1,
303            expr2,
304        } = self
305        {
306            if expr1.is_literal() {
307                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
308                return literal.call_binary(expr, func::Eq);
309            }
310            if expr2.is_literal() {
311                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
312                return literal.call_binary(expr, func::Eq);
313            }
314            // Note: The above return statements should be consistent in whether they put the
315            // literal in expr1 or expr2, for the deduplication in CanonicalizeMfp to work.
316        }
317        self.clone()
318    }
319
320    /// Given an `<expr>` and a `<literal>` that were taken out from `<expr> = <literal>` or
321    /// `<literal> = <expr>`, it tries to simplify the equality by applying the inverse function of
322    /// the outermost function call of `<expr>` (if exists):
323    ///
324    /// `<literal> = func(<inner_expr>)`, where `func` is invertible
325    ///  -->
326    /// `<func^-1(literal)> = <inner_expr>`
327    /// if `func^-1(literal)` doesn't error out, and both `func` and `func^-1` preserve uniqueness.
328    ///
329    /// The return value is the `<inner_expr>` and the literal value that we get by applying the
330    /// inverse function.
331    fn invert_casts_on_expr_eq_literal_inner(
332        expr: &MirScalarExpr,
333        literal: &MirScalarExpr,
334    ) -> (MirScalarExpr, MirScalarExpr) {
335        assert!(matches!(literal, MirScalarExpr::Literal(..)));
336
337        let temp_storage = &RowArena::new();
338        let eval = |e: &MirScalarExpr| {
339            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
340        };
341
342        if let MirScalarExpr::CallUnary {
343            func,
344            expr: inner_expr,
345        } = expr
346        {
347            if let Some(inverse_func) = func.inverse() {
348                // We don't want to remove a function call that doesn't preserve uniqueness, e.g.,
349                // if `f` is a float, we don't want to inverse-cast `f::INT = 0`, because the
350                // inserted int-to-float cast wouldn't be able to invert the rounding.
351                // Also, we don't want to insert a function call that doesn't preserve
352                // uniqueness. E.g., if `a` has an integer type, we don't want to do
353                // a surprise rounding for `WHERE a = 3.14`.
354                if func.preserves_uniqueness() && inverse_func.preserves_uniqueness() {
355                    let lit_inv = eval(&MirScalarExpr::CallUnary {
356                        func: inverse_func,
357                        expr: Box::new(literal.clone()),
358                    });
359                    // The evaluation can error out, e.g., when casting a too large int32 to int16.
360                    // This case is handled by `impossible_literal_equality_because_types`.
361                    if !lit_inv.is_literal_err() {
362                        return (*inner_expr.clone(), lit_inv);
363                    }
364                }
365            }
366        }
367        (expr.clone(), literal.clone())
368    }
369
370    /// Tries to remove a cast (or other invertible function) in the same way as
371    /// `invert_casts_on_expr_eq_literal`, but if calling the inverse function fails on the literal,
372    /// then it deems the equality to be impossible. For example if `a` is a smallint column, then
373    /// it catches `a::integer = 1000000` to be an always false predicate (where the `::integer`
374    /// could have been inserted implicitly).
375    pub fn impossible_literal_equality_because_types(&self) -> bool {
376        if let MirScalarExpr::CallBinary {
377            func: BinaryFunc::Eq(_),
378            expr1,
379            expr2,
380        } = self
381        {
382            if expr1.is_literal() {
383                return Self::impossible_literal_equality_because_types_inner(expr1, expr2);
384            }
385            if expr2.is_literal() {
386                return Self::impossible_literal_equality_because_types_inner(expr2, expr1);
387            }
388        }
389        false
390    }
391
392    fn impossible_literal_equality_because_types_inner(
393        literal: &MirScalarExpr,
394        other_side: &MirScalarExpr,
395    ) -> bool {
396        assert!(matches!(literal, MirScalarExpr::Literal(..)));
397
398        let temp_storage = &RowArena::new();
399        let eval = |e: &MirScalarExpr| {
400            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
401        };
402
403        if let MirScalarExpr::CallUnary { func, .. } = other_side {
404            if let Some(inverse_func) = func.inverse() {
405                if inverse_func.preserves_uniqueness()
406                    && eval(&MirScalarExpr::CallUnary {
407                        func: inverse_func,
408                        expr: Box::new(literal.clone()),
409                    })
410                    .is_literal_err()
411                {
412                    return true;
413                }
414            }
415        }
416
417        false
418    }
419
420    /// Determines if `self` is
421    /// `<expr> < <literal>` or
422    /// `<expr> > <literal>` or
423    /// `<literal> < <expr>` or
424    /// `<literal> > <expr>` or
425    /// `<expr> <= <literal>` or
426    /// `<expr> >= <literal>` or
427    /// `<literal> <= <expr>` or
428    /// `<literal> >= <expr>`.
429    pub fn any_expr_ineq_literal(&self) -> bool {
430        match self {
431            MirScalarExpr::CallBinary {
432                func:
433                    BinaryFunc::Lt(_) | BinaryFunc::Lte(_) | BinaryFunc::Gt(_) | BinaryFunc::Gte(_),
434                expr1,
435                expr2,
436            } => expr1.is_literal() || expr2.is_literal(),
437            _ => false,
438        }
439    }
440
441    /// Rewrites column indices with their value in `permutation`.
442    ///
443    /// This method is applicable even when `permutation` is not a
444    /// strict permutation, and it only needs to have entries for
445    /// each column referenced in `self`.
446    pub fn permute(&mut self, permutation: &[usize]) {
447        self.visit_columns(|c| *c = permutation[*c]);
448    }
449
450    /// Rewrites column indices with their value in `permutation`.
451    ///
452    /// This method is applicable even when `permutation` is not a
453    /// strict permutation, and it only needs to have entries for
454    /// each column referenced in `self`.
455    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
456        self.visit_columns(|c| *c = permutation[c]);
457    }
458
459    /// Visits each column reference and applies `action` to the column.
460    ///
461    /// Useful for remapping columns, or for collecting expression support.
462    pub fn visit_columns<F>(&mut self, mut action: F)
463    where
464        F: FnMut(&mut usize),
465    {
466        self.visit_pre_mut(|e| {
467            if let MirScalarExpr::Column(col, _) = e {
468                action(col);
469            }
470        });
471    }
472
473    pub fn support(&self) -> BTreeSet<usize> {
474        let mut support = BTreeSet::new();
475        self.support_into(&mut support);
476        support
477    }
478
479    pub fn support_into(&self, support: &mut BTreeSet<usize>) {
480        self.visit_pre(|e| {
481            if let MirScalarExpr::Column(i, _) = e {
482                support.insert(*i);
483            }
484        });
485    }
486
487    pub fn take(&mut self) -> Self {
488        mem::replace(self, MirScalarExpr::literal_null(SqlScalarType::String))
489    }
490
491    /// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
492    /// Otherwise, it returns None.
493    pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
494        if let MirScalarExpr::Literal(lit, _column_type) = self {
495            Some(lit.as_ref().map(|row| row.unpack_first()))
496        } else {
497            None
498        }
499    }
500
501    /// Flattens the two failure modes of `as_literal` into one layer of Option: returns the
502    /// literal's Datum only if the expression is a literal, and it's not a literal error.
503    pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
504        self.as_literal().map(|eval_err| eval_err.ok()).flatten()
505    }
506
507    pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
508        if let MirScalarExpr::Literal(lit, _column_type) = self {
509            Some(lit.clone())
510        } else {
511            None
512        }
513    }
514
515    pub fn as_literal_str(&self) -> Option<&str> {
516        match self.as_literal() {
517            Some(Ok(Datum::String(s))) => Some(s),
518            _ => None,
519        }
520    }
521
522    pub fn as_literal_int64(&self) -> Option<i64> {
523        match self.as_literal() {
524            Some(Ok(Datum::Int64(i))) => Some(i),
525            _ => None,
526        }
527    }
528
529    pub fn as_literal_err(&self) -> Option<&EvalError> {
530        self.as_literal().and_then(|lit| lit.err())
531    }
532
533    pub fn is_literal(&self) -> bool {
534        matches!(self, MirScalarExpr::Literal(_, _))
535    }
536
537    pub fn is_literal_true(&self) -> bool {
538        Some(Ok(Datum::True)) == self.as_literal()
539    }
540
541    pub fn is_literal_false(&self) -> bool {
542        Some(Ok(Datum::False)) == self.as_literal()
543    }
544
545    pub fn is_literal_null(&self) -> bool {
546        Some(Ok(Datum::Null)) == self.as_literal()
547    }
548
549    pub fn is_literal_ok(&self) -> bool {
550        matches!(self, MirScalarExpr::Literal(Ok(_), _typ))
551    }
552
553    pub fn is_literal_err(&self) -> bool {
554        matches!(self, MirScalarExpr::Literal(Err(_), _typ))
555    }
556
557    pub fn is_column(&self) -> bool {
558        matches!(self, MirScalarExpr::Column(_col, _name))
559    }
560
561    pub fn is_error_if_null(&self) -> bool {
562        matches!(
563            self,
564            Self::CallVariadic {
565                func: VariadicFunc::ErrorIfNull,
566                ..
567            }
568        )
569    }
570
571    /// If `self` expresses a temporal filter, normalize it to start with `mz_now()` and return
572    /// references.
573    ///
574    /// A temporal filter is an expression of the form `mz_now() <BINOP> <EXPR>`,
575    /// for a restricted set of `BINOP` and `EXPR` that do not themselves contain `mz_now()`.
576    /// Expressions may conform to this once their expressions are swapped.
577    ///
578    /// If the expression is not a temporal filter, it will be unchanged, and the reason for why
579    /// it's not a temporal filter is returned as a string.
580    pub fn as_mut_temporal_filter(&mut self) -> Result<(&BinaryFunc, &mut MirScalarExpr), String> {
581        if !self.contains_temporal() {
582            return Err("Does not involve mz_now()".to_string());
583        }
584        // Supported temporal predicates are exclusively binary operators.
585        if let MirScalarExpr::CallBinary { func, expr1, expr2 } = self {
586            // Attempt to put `LogicalTimestamp` in the first argument position.
587            if !expr1.contains_temporal()
588                && **expr2 == MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
589            {
590                std::mem::swap(expr1, expr2);
591                *func = match func {
592                    BinaryFunc::Eq(_) => func::Eq.into(),
593                    BinaryFunc::Lt(_) => func::Gt.into(),
594                    BinaryFunc::Lte(_) => func::Gte.into(),
595                    BinaryFunc::Gt(_) => func::Lt.into(),
596                    BinaryFunc::Gte(_) => func::Lte.into(),
597                    x => {
598                        return Err(format!("Unsupported binary temporal operation: {:?}", x));
599                    }
600                };
601            }
602
603            // Error if MLT is referenced in an unsupported position.
604            if expr2.contains_temporal()
605                || **expr1 != MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
606            {
607                return Err(format!(
608                    "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a mz_timestamp-castable expression. Expression found: {}",
609                    MirScalarExpr::CallBinary {
610                        func: func.clone(),
611                        expr1: expr1.clone(),
612                        expr2: expr2.clone()
613                    },
614                ));
615            }
616
617            Ok((&*func, expr2))
618        } else {
619            Err(format!(
620                "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a non-temporal expression of mz_timestamp-castable type. Expression found: {}",
621                self,
622            ))
623        }
624    }
625
626    #[deprecated = "Use `might_error` instead"]
627    pub fn contains_error_if_null(&self) -> bool {
628        let mut worklist = vec![self];
629        while let Some(expr) = worklist.pop() {
630            if expr.is_error_if_null() {
631                return true;
632            }
633            worklist.extend(expr.children());
634        }
635        false
636    }
637
638    pub fn contains_err(&self) -> bool {
639        let mut worklist = vec![self];
640        while let Some(expr) = worklist.pop() {
641            if expr.is_literal_err() {
642                return true;
643            }
644            worklist.extend(expr.children());
645        }
646        false
647    }
648
649    /// A very crude approximation for scalar expressions that might produce an
650    /// error.
651    ///
652    /// Currently, this is restricted only to expressions that either contain a
653    /// literal error or a [`VariadicFunc::ErrorIfNull`] call.
654    pub fn might_error(&self) -> bool {
655        let mut worklist = vec![self];
656        while let Some(expr) = worklist.pop() {
657            if expr.is_literal_err() || expr.is_error_if_null() {
658                return true;
659            }
660            worklist.extend(expr.children());
661        }
662        false
663    }
664
665    /// If self is a column, return the column index, otherwise `None`.
666    pub fn as_column(&self) -> Option<usize> {
667        if let MirScalarExpr::Column(c, _) = self {
668            Some(*c)
669        } else {
670            None
671        }
672    }
673
674    /// Reduces a complex expression where possible.
675    ///
676    /// This function uses nullability information present in `column_types`,
677    /// and the result may only continue to be a correct transformation as
678    /// long as this information continues to hold (nullability may not hold
679    /// as expressions migrate around).
680    ///
681    /// (If you'd like to not use nullability information here, then you can
682    /// tweak the nullabilities in `column_types` before passing it to this
683    /// function, see e.g. in `EquivalenceClasses::minimize`.)
684    ///
685    /// Also performs partial canonicalization on the expression.
686    ///
687    /// ```rust
688    /// use mz_expr::MirScalarExpr;
689    /// use mz_repr::{SqlColumnType, Datum, SqlScalarType};
690    ///
691    /// let expr_0 = MirScalarExpr::column(0);
692    /// let expr_t = MirScalarExpr::literal_true();
693    /// let expr_f = MirScalarExpr::literal_false();
694    ///
695    /// let mut test =
696    /// expr_t
697    ///     .clone()
698    ///     .and(expr_f.clone())
699    ///     .if_then_else(expr_0, expr_t.clone());
700    ///
701    /// let input_type = vec![SqlScalarType::Int32.nullable(false)];
702    /// test.reduce(&input_type);
703    /// assert_eq!(test, expr_t);
704    /// ```
705    pub fn reduce(&mut self, column_types: &[SqlColumnType]) {
706        let temp_storage = &RowArena::new();
707        let eval = |e: &MirScalarExpr| {
708            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type)
709        };
710
711        // Simplifications run in a loop until `self` no longer changes.
712        let mut old_self = MirScalarExpr::column(0);
713        while old_self != *self {
714            old_self = self.clone();
715            #[allow(deprecated)]
716            self.visit_mut_pre_post_nolimit(
717                &mut |e| {
718                    match e {
719                        MirScalarExpr::CallUnary { func, expr } => {
720                            if *func == UnaryFunc::IsNull(func::IsNull) {
721                                if !expr.typ(column_types).nullable {
722                                    *e = MirScalarExpr::literal_false();
723                                } else {
724                                    // Try to at least decompose IsNull into a disjunction
725                                    // of simpler IsNull subexpressions.
726                                    if let Some(expr) = expr.decompose_is_null() {
727                                        *e = expr
728                                    }
729                                }
730                            } else if *func == UnaryFunc::Not(func::Not) {
731                                // Push down not expressions
732                                match &mut **expr {
733                                    // Two negates cancel each other out.
734                                    MirScalarExpr::CallUnary {
735                                        expr: inner_expr,
736                                        func: UnaryFunc::Not(func::Not),
737                                    } => *e = inner_expr.take(),
738                                    // Transforms `NOT(a <op> b)` to `a negate(<op>) b`
739                                    // if a negation exists.
740                                    MirScalarExpr::CallBinary { expr1, expr2, func } => {
741                                        if let Some(negated_func) = func.negate() {
742                                            *e = MirScalarExpr::CallBinary {
743                                                expr1: Box::new(expr1.take()),
744                                                expr2: Box::new(expr2.take()),
745                                                func: negated_func,
746                                            }
747                                        }
748                                    }
749                                    MirScalarExpr::CallVariadic { .. } => {
750                                        e.demorgans();
751                                    }
752                                    _ => {}
753                                }
754                            }
755                        }
756                        _ => {}
757                    };
758                    None
759                },
760                &mut |e| match e {
761                    // Evaluate and pull up constants
762                    MirScalarExpr::Column(_, _)
763                    | MirScalarExpr::Literal(_, _)
764                    | MirScalarExpr::CallUnmaterializable(_) => (),
765                    MirScalarExpr::CallUnary { func, expr } => {
766                        if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
767                            *e = eval(e);
768                        } else if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
769                            if let MirScalarExpr::CallVariadic {
770                                func: VariadicFunc::RecordCreate { .. },
771                                exprs,
772                            } = &mut **expr
773                            {
774                                *e = exprs.swap_remove(i);
775                            }
776                        }
777                    }
778                    MirScalarExpr::CallBinary { func, expr1, expr2 } => {
779                        if expr1.is_literal() && expr2.is_literal() {
780                            *e = eval(e);
781                        } else if (expr1.is_literal_null() || expr2.is_literal_null())
782                            && func.propagates_nulls()
783                        {
784                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
785                        } else if let Some(err) = expr1.as_literal_err() {
786                            *e = MirScalarExpr::literal(
787                                Err(err.clone()),
788                                e.typ(column_types).scalar_type,
789                            );
790                        } else if let Some(err) = expr2.as_literal_err() {
791                            *e = MirScalarExpr::literal(
792                                Err(err.clone()),
793                                e.typ(column_types).scalar_type,
794                            );
795                        } else if let BinaryFunc::IsLikeMatchCaseInsensitive(_) = func {
796                            if expr2.is_literal() {
797                                // The pattern is a literal, so we can at least precompile the LIKE matcher.
798                                let pattern = expr2.as_literal_str().unwrap();
799                                *e = match like_pattern::compile(pattern, true) {
800                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
801                                        func::IsLikeMatch(matcher),
802                                    )),
803                                    Err(err) => MirScalarExpr::literal(
804                                        Err(err),
805                                        e.typ(column_types).scalar_type,
806                                    ),
807                                };
808                            }
809                        } else if let BinaryFunc::IsLikeMatchCaseSensitive(_) = func {
810                            if expr2.is_literal() {
811                                // The pattern is a literal, so we can at least precompile the LIKE matcher.
812                                let pattern = expr2.as_literal_str().unwrap();
813                                *e = match like_pattern::compile(pattern, false) {
814                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
815                                        func::IsLikeMatch(matcher),
816                                    )),
817                                    Err(err) => MirScalarExpr::literal(
818                                        Err(err),
819                                        e.typ(column_types).scalar_type,
820                                    ),
821                                };
822                            }
823                        } else if matches!(
824                            func,
825                            BinaryFunc::IsRegexpMatchCaseSensitive(_)
826                                | BinaryFunc::IsRegexpMatchCaseInsensitive(_)
827                        ) {
828                            let case_insensitive =
829                                matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
830                            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
831                                *e = match Regex::new(
832                                    row.unpack_first().unwrap_str(),
833                                    case_insensitive,
834                                ) {
835                                    Ok(regex) => expr1.take().call_unary(UnaryFunc::IsRegexpMatch(
836                                        func::IsRegexpMatch(regex),
837                                    )),
838                                    Err(err) => MirScalarExpr::literal(
839                                        Err(err.into()),
840                                        e.typ(column_types).scalar_type,
841                                    ),
842                                };
843                            }
844                        } else if let BinaryFunc::ExtractInterval(_) = *func
845                            && expr1.is_literal()
846                        {
847                            let units = expr1.as_literal_str().unwrap();
848                            *e = match units.parse::<DateTimeUnits>() {
849                                Ok(units) => MirScalarExpr::CallUnary {
850                                    func: UnaryFunc::ExtractInterval(func::ExtractInterval(units)),
851                                    expr: Box::new(expr2.take()),
852                                },
853                                Err(_) => MirScalarExpr::literal(
854                                    Err(EvalError::UnknownUnits(units.into())),
855                                    e.typ(column_types).scalar_type,
856                                ),
857                            }
858                        } else if let BinaryFunc::ExtractTime(_) = *func
859                            && expr1.is_literal()
860                        {
861                            let units = expr1.as_literal_str().unwrap();
862                            *e = match units.parse::<DateTimeUnits>() {
863                                Ok(units) => MirScalarExpr::CallUnary {
864                                    func: UnaryFunc::ExtractTime(func::ExtractTime(units)),
865                                    expr: Box::new(expr2.take()),
866                                },
867                                Err(_) => MirScalarExpr::literal(
868                                    Err(EvalError::UnknownUnits(units.into())),
869                                    e.typ(column_types).scalar_type,
870                                ),
871                            }
872                        } else if let BinaryFunc::ExtractTimestamp(_) = *func
873                            && expr1.is_literal()
874                        {
875                            let units = expr1.as_literal_str().unwrap();
876                            *e = match units.parse::<DateTimeUnits>() {
877                                Ok(units) => MirScalarExpr::CallUnary {
878                                    func: UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(
879                                        units,
880                                    )),
881                                    expr: Box::new(expr2.take()),
882                                },
883                                Err(_) => MirScalarExpr::literal(
884                                    Err(EvalError::UnknownUnits(units.into())),
885                                    e.typ(column_types).scalar_type,
886                                ),
887                            }
888                        } else if let BinaryFunc::ExtractTimestampTz(_) = *func
889                            && expr1.is_literal()
890                        {
891                            let units = expr1.as_literal_str().unwrap();
892                            *e = match units.parse::<DateTimeUnits>() {
893                                Ok(units) => MirScalarExpr::CallUnary {
894                                    func: UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(
895                                        units,
896                                    )),
897                                    expr: Box::new(expr2.take()),
898                                },
899                                Err(_) => MirScalarExpr::literal(
900                                    Err(EvalError::UnknownUnits(units.into())),
901                                    e.typ(column_types).scalar_type,
902                                ),
903                            }
904                        } else if let BinaryFunc::ExtractDate(_) = *func
905                            && expr1.is_literal()
906                        {
907                            let units = expr1.as_literal_str().unwrap();
908                            *e = match units.parse::<DateTimeUnits>() {
909                                Ok(units) => MirScalarExpr::CallUnary {
910                                    func: UnaryFunc::ExtractDate(func::ExtractDate(units)),
911                                    expr: Box::new(expr2.take()),
912                                },
913                                Err(_) => MirScalarExpr::literal(
914                                    Err(EvalError::UnknownUnits(units.into())),
915                                    e.typ(column_types).scalar_type,
916                                ),
917                            }
918                        } else if let BinaryFunc::DatePartInterval(_) = *func
919                            && expr1.is_literal()
920                        {
921                            let units = expr1.as_literal_str().unwrap();
922                            *e = match units.parse::<DateTimeUnits>() {
923                                Ok(units) => MirScalarExpr::CallUnary {
924                                    func: UnaryFunc::DatePartInterval(func::DatePartInterval(
925                                        units,
926                                    )),
927                                    expr: Box::new(expr2.take()),
928                                },
929                                Err(_) => MirScalarExpr::literal(
930                                    Err(EvalError::UnknownUnits(units.into())),
931                                    e.typ(column_types).scalar_type,
932                                ),
933                            }
934                        } else if let BinaryFunc::DatePartTime(_) = *func
935                            && expr1.is_literal()
936                        {
937                            let units = expr1.as_literal_str().unwrap();
938                            *e = match units.parse::<DateTimeUnits>() {
939                                Ok(units) => MirScalarExpr::CallUnary {
940                                    func: UnaryFunc::DatePartTime(func::DatePartTime(units)),
941                                    expr: Box::new(expr2.take()),
942                                },
943                                Err(_) => MirScalarExpr::literal(
944                                    Err(EvalError::UnknownUnits(units.into())),
945                                    e.typ(column_types).scalar_type,
946                                ),
947                            }
948                        } else if let BinaryFunc::DatePartTimestamp(_) = *func
949                            && expr1.is_literal()
950                        {
951                            let units = expr1.as_literal_str().unwrap();
952                            *e = match units.parse::<DateTimeUnits>() {
953                                Ok(units) => MirScalarExpr::CallUnary {
954                                    func: UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(
955                                        units,
956                                    )),
957                                    expr: Box::new(expr2.take()),
958                                },
959                                Err(_) => MirScalarExpr::literal(
960                                    Err(EvalError::UnknownUnits(units.into())),
961                                    e.typ(column_types).scalar_type,
962                                ),
963                            }
964                        } else if let BinaryFunc::DatePartTimestampTz(_) = *func
965                            && expr1.is_literal()
966                        {
967                            let units = expr1.as_literal_str().unwrap();
968                            *e = match units.parse::<DateTimeUnits>() {
969                                Ok(units) => MirScalarExpr::CallUnary {
970                                    func: UnaryFunc::DatePartTimestampTz(
971                                        func::DatePartTimestampTz(units),
972                                    ),
973                                    expr: Box::new(expr2.take()),
974                                },
975                                Err(_) => MirScalarExpr::literal(
976                                    Err(EvalError::UnknownUnits(units.into())),
977                                    e.typ(column_types).scalar_type,
978                                ),
979                            }
980                        } else if let BinaryFunc::DateTruncTimestamp(_) = *func
981                            && expr1.is_literal()
982                        {
983                            let units = expr1.as_literal_str().unwrap();
984                            *e = match units.parse::<DateTimeUnits>() {
985                                Ok(units) => MirScalarExpr::CallUnary {
986                                    func: UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(
987                                        units,
988                                    )),
989                                    expr: Box::new(expr2.take()),
990                                },
991                                Err(_) => MirScalarExpr::literal(
992                                    Err(EvalError::UnknownUnits(units.into())),
993                                    e.typ(column_types).scalar_type,
994                                ),
995                            }
996                        } else if let BinaryFunc::DateTruncTimestampTz(_) = *func
997                            && expr1.is_literal()
998                        {
999                            let units = expr1.as_literal_str().unwrap();
1000                            *e = match units.parse::<DateTimeUnits>() {
1001                                Ok(units) => MirScalarExpr::CallUnary {
1002                                    func: UnaryFunc::DateTruncTimestampTz(
1003                                        func::DateTruncTimestampTz(units),
1004                                    ),
1005                                    expr: Box::new(expr2.take()),
1006                                },
1007                                Err(_) => MirScalarExpr::literal(
1008                                    Err(EvalError::UnknownUnits(units.into())),
1009                                    e.typ(column_types).scalar_type,
1010                                ),
1011                            }
1012                        } else if matches!(func, BinaryFunc::TimezoneTimestampBinary(_))
1013                            && expr1.is_literal()
1014                        {
1015                            // If the timezone argument is a literal, and we're applying the function on many rows at the same
1016                            // time we really don't want to parse it again and again, so we parse it once and embed it into the
1017                            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
1018                            let tz = expr1.as_literal_str().unwrap();
1019                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1020                                Ok(tz) => MirScalarExpr::CallUnary {
1021                                    func: UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz)),
1022                                    expr: Box::new(expr2.take()),
1023                                },
1024                                Err(err) => MirScalarExpr::literal(
1025                                    Err(err),
1026                                    e.typ(column_types).scalar_type,
1027                                ),
1028                            }
1029                        } else if matches!(func, BinaryFunc::TimezoneTimestampTzBinary(_))
1030                            && expr1.is_literal()
1031                        {
1032                            let tz = expr1.as_literal_str().unwrap();
1033                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1034                                Ok(tz) => MirScalarExpr::CallUnary {
1035                                    func: UnaryFunc::TimezoneTimestampTz(
1036                                        func::TimezoneTimestampTz(tz),
1037                                    ),
1038                                    expr: Box::new(expr2.take()),
1039                                },
1040                                Err(err) => MirScalarExpr::literal(
1041                                    Err(err),
1042                                    e.typ(column_types).scalar_type,
1043                                ),
1044                            }
1045                        } else if let BinaryFunc::ToCharTimestamp(_) = *func
1046                            && expr2.is_literal()
1047                        {
1048                            let format_str = expr2.as_literal_str().unwrap();
1049                            *e = MirScalarExpr::CallUnary {
1050                                func: UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
1051                                    format_string: format_str.to_string(),
1052                                    format: DateTimeFormat::compile(format_str),
1053                                }),
1054                                expr: Box::new(expr1.take()),
1055                            };
1056                        } else if let BinaryFunc::ToCharTimestampTz(_) = *func
1057                            && expr2.is_literal()
1058                        {
1059                            let format_str = expr2.as_literal_str().unwrap();
1060                            *e = MirScalarExpr::CallUnary {
1061                                func: UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
1062                                    format_string: format_str.to_string(),
1063                                    format: DateTimeFormat::compile(format_str),
1064                                }),
1065                                expr: Box::new(expr1.take()),
1066                            };
1067                        } else if matches!(*func, BinaryFunc::Eq(_) | BinaryFunc::NotEq(_))
1068                            && expr2 < expr1
1069                        {
1070                            // Canonically order elements so that deduplication works better.
1071                            // Also, the below `Literal([c1, c2]) = record_create(e1, e2)` matching
1072                            // relies on this canonical ordering.
1073                            mem::swap(expr1, expr2);
1074                        } else if let (
1075                            BinaryFunc::Eq(_),
1076                            MirScalarExpr::Literal(
1077                                Ok(lit_row),
1078                                SqlColumnType {
1079                                    scalar_type:
1080                                        SqlScalarType::Record {
1081                                            fields: field_types,
1082                                            ..
1083                                        },
1084                                    ..
1085                                },
1086                            ),
1087                            MirScalarExpr::CallVariadic {
1088                                func: VariadicFunc::RecordCreate { .. },
1089                                exprs: rec_create_args,
1090                            },
1091                        ) = (&*func, &**expr1, &**expr2)
1092                        {
1093                            // Literal([c1, c2]) = record_create(e1, e2)
1094                            //  -->
1095                            // c1 = e1 AND c2 = e2
1096                            //
1097                            // (Records are represented as lists.)
1098                            //
1099                            // `MapFilterProject::literal_constraints` relies on this transform,
1100                            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
1101                            match lit_row.unpack_first() {
1102                                Datum::List(datum_list) => {
1103                                    *e = MirScalarExpr::CallVariadic {
1104                                        func: VariadicFunc::And,
1105                                        exprs: datum_list
1106                                            .iter()
1107                                            .zip_eq(field_types)
1108                                            .zip_eq(rec_create_args)
1109                                            .map(|((d, (_, typ)), a)| {
1110                                                MirScalarExpr::literal_ok(
1111                                                    d,
1112                                                    typ.scalar_type.clone(),
1113                                                )
1114                                                .call_binary(a.clone(), func::Eq)
1115                                            })
1116                                            .collect(),
1117                                    };
1118                                }
1119                                _ => {}
1120                            }
1121                        } else if let (
1122                            BinaryFunc::Eq(_),
1123                            MirScalarExpr::CallVariadic {
1124                                func: VariadicFunc::RecordCreate { .. },
1125                                exprs: rec_create_args1,
1126                            },
1127                            MirScalarExpr::CallVariadic {
1128                                func: VariadicFunc::RecordCreate { .. },
1129                                exprs: rec_create_args2,
1130                            },
1131                        ) = (&*func, &**expr1, &**expr2)
1132                        {
1133                            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
1134                            //  -->
1135                            // a1 = b1 AND a2 = b2 AND ...
1136                            //
1137                            // This is similar to the previous reduction, but this one kicks in also
1138                            // when only some (or none) of the record fields are literals. This
1139                            // enables the discovery of literal constraints for those fields.
1140                            //
1141                            // Note that there is a similar decomposition in
1142                            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
1143                            // pipeline than the compilation of IN lists to `record_create`.
1144                            *e = MirScalarExpr::CallVariadic {
1145                                func: VariadicFunc::And,
1146                                exprs: rec_create_args1
1147                                    .into_iter()
1148                                    .zip_eq(rec_create_args2)
1149                                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
1150                                    .collect(),
1151                            }
1152                        }
1153                    }
1154                    MirScalarExpr::CallVariadic { .. } => {
1155                        e.flatten_associative();
1156                        let (func, exprs) = match e {
1157                            MirScalarExpr::CallVariadic { func, exprs } => (func, exprs),
1158                            _ => unreachable!("`flatten_associative` shouldn't change node type"),
1159                        };
1160                        if *func == VariadicFunc::Coalesce {
1161                            // If all inputs are null, output is null. This check must
1162                            // be done before `exprs.retain...` because `e.typ` requires
1163                            // > 0 `exprs` remain.
1164                            if exprs.iter().all(|expr| expr.is_literal_null()) {
1165                                *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1166                                return;
1167                            }
1168
1169                            // Remove any null values if not all values are null.
1170                            exprs.retain(|e| !e.is_literal_null());
1171
1172                            // Find the first argument that is a literal or non-nullable
1173                            // column. All arguments after it get ignored, so throw them
1174                            // away. This intentionally throws away errors that can
1175                            // never happen.
1176                            if let Some(i) = exprs
1177                                .iter()
1178                                .position(|e| e.is_literal() || !e.typ(column_types).nullable)
1179                            {
1180                                exprs.truncate(i + 1);
1181                            }
1182
1183                            // Deduplicate arguments in cases like `coalesce(#0, #0)`.
1184                            let mut prior_exprs = BTreeSet::new();
1185                            exprs.retain(|e| prior_exprs.insert(e.clone()));
1186
1187                            if exprs.len() == 1 {
1188                                // Only one argument, so the coalesce is a no-op.
1189                                *e = exprs[0].take();
1190                            }
1191                        } else if exprs.iter().all(|e| e.is_literal()) {
1192                            *e = eval(e);
1193                        } else if func.propagates_nulls()
1194                            && exprs.iter().any(|e| e.is_literal_null())
1195                        {
1196                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1197                        } else if let Some(err) = exprs.iter().find_map(|e| e.as_literal_err()) {
1198                            *e = MirScalarExpr::literal(
1199                                Err(err.clone()),
1200                                e.typ(column_types).scalar_type,
1201                            );
1202                        } else if *func == VariadicFunc::RegexpMatch
1203                            && exprs[1].is_literal()
1204                            && exprs.get(2).map_or(true, |e| e.is_literal())
1205                        {
1206                            let needle = exprs[1].as_literal_str().unwrap();
1207                            let flags = match exprs.len() {
1208                                3 => exprs[2].as_literal_str().unwrap(),
1209                                _ => "",
1210                            };
1211                            *e = match func::build_regex(needle, flags) {
1212                                Ok(regex) => mem::take(exprs)
1213                                    .into_first()
1214                                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
1215                                Err(err) => MirScalarExpr::literal(
1216                                    Err(err),
1217                                    e.typ(column_types).scalar_type,
1218                                ),
1219                            };
1220                        } else if *func == VariadicFunc::RegexpReplace
1221                            && exprs[1].is_literal()
1222                            && exprs.get(3).map_or(true, |e| e.is_literal())
1223                        {
1224                            let pattern = exprs[1].as_literal_str().unwrap();
1225                            let flags = exprs
1226                                .get(3)
1227                                .map_or("", |expr| expr.as_literal_str().unwrap());
1228                            let (limit, flags) = regexp_replace_parse_flags(flags);
1229
1230                            // The behavior of `regexp_replace` is that if the data is `NULL`, the
1231                            // function returns `NULL`, independently of whether the pattern or
1232                            // flags are correct. We need to check for this case and introduce an
1233                            // if-then-else on the error path to only surface the error if the first
1234                            // input is non-NULL.
1235                            *e = match func::build_regex(pattern, &flags) {
1236                                Ok(regex) => {
1237                                    let mut exprs = mem::take(exprs);
1238                                    let replacement = exprs.swap_remove(2);
1239                                    let source = exprs.swap_remove(0);
1240                                    source.call_binary(
1241                                        replacement,
1242                                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
1243                                    )
1244                                }
1245                                Err(err) => {
1246                                    let mut exprs = mem::take(exprs);
1247                                    let source = exprs.swap_remove(0);
1248                                    let scalar_type = e.typ(column_types).scalar_type;
1249                                    // We need to return `NULL` on `NULL` input, and error otherwise.
1250                                    source.call_is_null().if_then_else(
1251                                        MirScalarExpr::literal_null(scalar_type.clone()),
1252                                        MirScalarExpr::literal(Err(err), scalar_type),
1253                                    )
1254                                }
1255                            };
1256                        } else if *func == VariadicFunc::RegexpSplitToArray
1257                            && exprs[1].is_literal()
1258                            && exprs.get(2).map_or(true, |e| e.is_literal())
1259                        {
1260                            let needle = exprs[1].as_literal_str().unwrap();
1261                            let flags = match exprs.len() {
1262                                3 => exprs[2].as_literal_str().unwrap(),
1263                                _ => "",
1264                            };
1265                            *e = match func::build_regex(needle, flags) {
1266                                Ok(regex) => mem::take(exprs).into_first().call_unary(
1267                                    UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(regex)),
1268                                ),
1269                                Err(err) => MirScalarExpr::literal(
1270                                    Err(err),
1271                                    e.typ(column_types).scalar_type,
1272                                ),
1273                            };
1274                        } else if *func == VariadicFunc::ListIndex && is_list_create_call(&exprs[0])
1275                        {
1276                            // We are looking for ListIndex(ListCreate, literal), and eliminate
1277                            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
1278                            let ind_exprs = exprs.split_off(1);
1279                            let top_list_create = exprs.swap_remove(0);
1280                            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
1281                        } else if *func == VariadicFunc::Or || *func == VariadicFunc::And {
1282                            // Note: It's important that we have called `flatten_associative` above.
1283                            e.undistribute_and_or();
1284                            e.reduce_and_canonicalize_and_or();
1285                        } else if let VariadicFunc::TimezoneTime = func {
1286                            if exprs[0].is_literal() && exprs[2].is_literal_ok() {
1287                                let tz = exprs[0].as_literal_str().unwrap();
1288                                *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1289                                    Ok(tz) => MirScalarExpr::CallUnary {
1290                                        func: UnaryFunc::TimezoneTime(func::TimezoneTime {
1291                                            tz,
1292                                            wall_time: exprs[2]
1293                                                .as_literal()
1294                                                .unwrap()
1295                                                .unwrap()
1296                                                .unwrap_timestamptz()
1297                                                .naive_utc(),
1298                                        }),
1299                                        expr: Box::new(exprs[1].take()),
1300                                    },
1301                                    Err(err) => MirScalarExpr::literal(
1302                                        Err(err),
1303                                        e.typ(column_types).scalar_type,
1304                                    ),
1305                                }
1306                            }
1307                        }
1308                    }
1309                    MirScalarExpr::If { cond, then, els } => {
1310                        if let Some(literal) = cond.as_literal() {
1311                            match literal {
1312                                Ok(Datum::True) => *e = then.take(),
1313                                Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
1314                                Err(err) => {
1315                                    *e = MirScalarExpr::Literal(
1316                                        Err(err.clone()),
1317                                        then.typ(column_types).union(&els.typ(column_types)),
1318                                    )
1319                                }
1320                                _ => unreachable!(),
1321                            }
1322                        } else if then == els {
1323                            *e = then.take();
1324                        } else if then.is_literal_ok()
1325                            && els.is_literal_ok()
1326                            && then.typ(column_types).scalar_type == SqlScalarType::Bool
1327                            && els.typ(column_types).scalar_type == SqlScalarType::Bool
1328                        {
1329                            match (then.as_literal(), els.as_literal()) {
1330                                // Note: NULLs from the condition should not be propagated to the result
1331                                // of the expression.
1332                                (Some(Ok(Datum::True)), _) => {
1333                                    // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
1334                                    // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
1335                                    *e = cond
1336                                        .clone()
1337                                        .call_is_null()
1338                                        .not()
1339                                        .and(cond.take())
1340                                        .or(els.take());
1341                                }
1342                                (Some(Ok(Datum::False)), _) => {
1343                                    // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
1344                                    // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
1345                                    *e = cond
1346                                        .clone()
1347                                        .not()
1348                                        .or(cond.take().call_is_null())
1349                                        .and(els.take());
1350                                }
1351                                (_, Some(Ok(Datum::True))) => {
1352                                    // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
1353                                    // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
1354                                    *e = cond
1355                                        .clone()
1356                                        .not()
1357                                        .or(cond.take().call_is_null())
1358                                        .or(then.take());
1359                                }
1360                                (_, Some(Ok(Datum::False))) => {
1361                                    // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
1362                                    // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
1363                                    *e = cond
1364                                        .clone()
1365                                        .call_is_null()
1366                                        .not()
1367                                        .and(cond.take())
1368                                        .and(then.take());
1369                                }
1370                                _ => {}
1371                            }
1372                        } else {
1373                            // Equivalent expression structure would allow us to push the `If` into the expression.
1374                            // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
1375                            //
1376                            // We have to also make sure that the expressions that will end up in
1377                            // the two `If` branches have unionable types. Otherwise, the `If` could
1378                            // not be typed by `typ`. An example where this could cause an issue is
1379                            // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
1380                            // range of input types. (In theory, we could still do the optimization
1381                            // in this case by inserting appropriate casts, but this corner case is
1382                            // not worth the complication for now.)
1383                            // See https://github.com/MaterializeInc/database-issues/issues/9182
1384                            match (&mut **then, &mut **els) {
1385                                (
1386                                    MirScalarExpr::CallUnary { func: f1, expr: e1 },
1387                                    MirScalarExpr::CallUnary { func: f2, expr: e2 },
1388                                ) if f1 == f2
1389                                    && e1
1390                                        .typ(column_types)
1391                                        .try_union(&e2.typ(column_types))
1392                                        .is_ok() =>
1393                                {
1394                                    *e = cond
1395                                        .take()
1396                                        .if_then_else(e1.take(), e2.take())
1397                                        .call_unary(f1.clone());
1398                                }
1399                                (
1400                                    MirScalarExpr::CallBinary {
1401                                        func: f1,
1402                                        expr1: e1a,
1403                                        expr2: e2a,
1404                                    },
1405                                    MirScalarExpr::CallBinary {
1406                                        func: f2,
1407                                        expr1: e1b,
1408                                        expr2: e2b,
1409                                    },
1410                                ) if f1 == f2
1411                                    && e1a == e1b
1412                                    && e2a
1413                                        .typ(column_types)
1414                                        .try_union(&e2b.typ(column_types))
1415                                        .is_ok() =>
1416                                {
1417                                    *e = e1a.take().call_binary(
1418                                        cond.take().if_then_else(e2a.take(), e2b.take()),
1419                                        f1.clone(),
1420                                    );
1421                                }
1422                                (
1423                                    MirScalarExpr::CallBinary {
1424                                        func: f1,
1425                                        expr1: e1a,
1426                                        expr2: e2a,
1427                                    },
1428                                    MirScalarExpr::CallBinary {
1429                                        func: f2,
1430                                        expr1: e1b,
1431                                        expr2: e2b,
1432                                    },
1433                                ) if f1 == f2
1434                                    && e2a == e2b
1435                                    && e1a
1436                                        .typ(column_types)
1437                                        .try_union(&e1b.typ(column_types))
1438                                        .is_ok() =>
1439                                {
1440                                    *e = cond
1441                                        .take()
1442                                        .if_then_else(e1a.take(), e1b.take())
1443                                        .call_binary(e2a.take(), f1.clone());
1444                                }
1445                                _ => {}
1446                            }
1447                        }
1448                    }
1449                },
1450            );
1451        }
1452
1453        /* #region `reduce_list_create_list_index_literal` and helper functions */
1454
1455        fn list_create_type(list_create: &MirScalarExpr) -> SqlScalarType {
1456            if let MirScalarExpr::CallVariadic {
1457                func: VariadicFunc::ListCreate { elem_type: typ },
1458                ..
1459            } = list_create
1460            {
1461                (*typ).clone()
1462            } else {
1463                unreachable!()
1464            }
1465        }
1466
1467        fn is_list_create_call(expr: &MirScalarExpr) -> bool {
1468            matches!(
1469                expr,
1470                MirScalarExpr::CallVariadic {
1471                    func: VariadicFunc::ListCreate { .. },
1472                    ..
1473                }
1474            )
1475        }
1476
1477        /// Partial-evaluates a list indexing with a literal directly after a list creation.
1478        ///
1479        /// Multi-dimensional lists are handled by a single call to this function, with multiple
1480        /// elements in index_exprs (of which not all need to be literals), and nested ListCreates
1481        /// in list_create_to_reduce.
1482        ///
1483        /// # Examples
1484        ///
1485        /// `LIST[f1,f2][2]` --> `f2`.
1486        ///
1487        /// A multi-dimensional list, with only some of the indexes being literals:
1488        /// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
1489        ///
1490        /// See more examples in list.slt.
1491        fn reduce_list_create_list_index_literal(
1492            mut list_create_to_reduce: MirScalarExpr,
1493            mut index_exprs: Vec<MirScalarExpr>,
1494        ) -> MirScalarExpr {
1495            // We iterate over the index_exprs and remove literals, but keep non-literals.
1496            // When we encounter a non-literal, we need to dig into the nested ListCreates:
1497            // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
1498            // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
1499            // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
1500            // that are at the current level (except those that disappeared due to
1501            // literals at earlier levels), index into them with the literal, and change each
1502            // element in `list_create_mut_refs` to the result.
1503            // We also record mut refs to all the earlier `element_type` references that we have
1504            // seen in ListCreate calls, because when we process a literal index, we need to remove
1505            // one layer of list type from all these earlier ListCreate `element_type`s.
1506            let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
1507            let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
1508            let mut i = 0;
1509            while i < index_exprs.len()
1510                && list_create_mut_refs
1511                    .iter()
1512                    .all(|lc| is_list_create_call(lc))
1513            {
1514                if index_exprs[i].is_literal_ok() {
1515                    // We can remove this index.
1516                    let removed_index = index_exprs.remove(i);
1517                    let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
1518                        Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
1519                        _ => unreachable!(), // always an Int64, see plan_index_list
1520                    };
1521                    // For each list_create referenced by list_create_mut_refs, substitute it by its
1522                    // `index`th argument (or null).
1523                    for list_create in &mut list_create_mut_refs {
1524                        let list_create_args = match list_create {
1525                            MirScalarExpr::CallVariadic {
1526                                func: VariadicFunc::ListCreate { elem_type: _ },
1527                                exprs,
1528                            } => exprs,
1529                            _ => unreachable!(), // func cannot be anything else than a ListCreate
1530                        };
1531                        // ListIndex gives null on an out-of-bounds index
1532                        if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap()
1533                        {
1534                            let index: usize = index_i64.try_into().unwrap();
1535                            **list_create = list_create_args.swap_remove(index);
1536                        } else {
1537                            let typ = list_create_type(list_create);
1538                            **list_create = MirScalarExpr::literal_null(typ);
1539                        }
1540                    }
1541                    // Peel one layer off of each of the earlier element types.
1542                    for t in earlier_list_create_types.iter_mut() {
1543                        if let SqlScalarType::List {
1544                            element_type,
1545                            custom_id: _,
1546                        } = t
1547                        {
1548                            **t = *element_type.clone();
1549                            // These are not the same types anymore, so remove custom_ids all the
1550                            // way down.
1551                            let mut u = &mut **t;
1552                            while let SqlScalarType::List {
1553                                element_type,
1554                                custom_id,
1555                            } = u
1556                            {
1557                                *custom_id = None;
1558                                u = &mut **element_type;
1559                            }
1560                        } else {
1561                            unreachable!("already matched below");
1562                        }
1563                    }
1564                } else {
1565                    // We can't remove this index, so we can't reduce any of the ListCreates at this
1566                    // level. So we change list_create_mut_refs to refer to all the arguments of all
1567                    // the ListCreates currently referenced by list_create_mut_refs.
1568                    list_create_mut_refs = list_create_mut_refs
1569                        .into_iter()
1570                        .flat_map(|list_create| match list_create {
1571                            MirScalarExpr::CallVariadic {
1572                                func: VariadicFunc::ListCreate { elem_type },
1573                                exprs: list_create_args,
1574                            } => {
1575                                earlier_list_create_types.push(elem_type);
1576                                list_create_args
1577                            }
1578                            // func cannot be anything else than a ListCreate
1579                            _ => unreachable!(),
1580                        })
1581                        .collect();
1582                    i += 1; // next index_expr
1583                }
1584            }
1585            // If all list indexes have been evaluated, return the reduced expression.
1586            // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
1587            if index_exprs.is_empty() {
1588                assert_eq!(list_create_mut_refs.len(), 1);
1589                list_create_to_reduce
1590            } else {
1591                let mut exprs: Vec<MirScalarExpr> = vec![list_create_to_reduce];
1592                exprs.append(&mut index_exprs);
1593                MirScalarExpr::CallVariadic {
1594                    func: VariadicFunc::ListIndex,
1595                    exprs,
1596                }
1597            }
1598        }
1599
1600        /* #endregion */
1601    }
1602
1603    /// Decompose an IsNull expression into a disjunction of
1604    /// simpler expressions.
1605    ///
1606    /// Assumes that `self` is the expression inside of an IsNull.
1607    /// Returns `Some(expressions)` if the outer IsNull is to be
1608    /// replaced by some other expression. Note: if it returns
1609    /// None, it might still have mutated *self.
1610    fn decompose_is_null(&mut self) -> Option<MirScalarExpr> {
1611        // TODO: allow simplification of unmaterializable functions
1612
1613        match self {
1614            MirScalarExpr::CallUnary {
1615                func,
1616                expr: inner_expr,
1617            } => {
1618                if !func.introduces_nulls() {
1619                    if func.propagates_nulls() {
1620                        *self = inner_expr.take();
1621                        return self.decompose_is_null();
1622                    } else {
1623                        // We can simplify to `false`, because the function simply can't produce
1624                        // nulls at all. This is because
1625                        // - !propagates_nulls means that the input type of the Rust function is not
1626                        //   nullable, so the automatic null propagation won't kick in;
1627                        // - !introduces_nulls means that the output type of the Rust function is
1628                        //   not nullable, so the Rust function can't produce a null manually either.
1629                        //
1630                        // Note that we can't do this same optimization for binary and variadic
1631                        // functions. This is because for binary and variadic functions the value of
1632                        // propagates_nulls and introduces_nulls is not derived solely from the
1633                        // input/output type nullabilities, but instead depends on what the Rust
1634                        // function does. For example, list concatenation neither introduces nor
1635                        // propagates nulls, but it can produce a null:
1636                        // - It does not introduce nulls, because if both input lists are not null,
1637                        //   then it will produce a list.
1638                        // - It does not propagate nulls, because giving a null as just one of the
1639                        //   arguments returns the other argument instead of null.
1640                        // - It does produce a null if both arguments are null.
1641                        return Some(MirScalarExpr::literal_false());
1642                    }
1643                }
1644            }
1645            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1646                // (<expr1> <op> <expr2>) IS NULL can often be simplified to
1647                // (<expr1> IS NULL) OR (<expr2> IS NULL).
1648                if func.propagates_nulls() && !func.introduces_nulls() {
1649                    let expr1 = expr1.take().call_is_null();
1650                    let expr2 = expr2.take().call_is_null();
1651                    return Some(expr1.or(expr2));
1652                }
1653            }
1654            MirScalarExpr::CallVariadic { func, exprs } => {
1655                if func.propagates_nulls() && !func.introduces_nulls() {
1656                    let exprs = exprs.into_iter().map(|e| e.take().call_is_null()).collect();
1657                    return Some(MirScalarExpr::CallVariadic {
1658                        func: VariadicFunc::Or,
1659                        exprs,
1660                    });
1661                }
1662            }
1663            _ => {}
1664        }
1665
1666        None
1667    }
1668
1669    /// Flattens a chain of calls to associative variadic functions
1670    /// (For example: ORs or ANDs)
1671    pub fn flatten_associative(&mut self) {
1672        match self {
1673            MirScalarExpr::CallVariadic {
1674                exprs: outer_operands,
1675                func: outer_func,
1676            } if outer_func.is_associative() => {
1677                *outer_operands = outer_operands
1678                    .into_iter()
1679                    .flat_map(|o| {
1680                        if let MirScalarExpr::CallVariadic {
1681                            exprs: inner_operands,
1682                            func: inner_func,
1683                        } = o
1684                        {
1685                            if *inner_func == *outer_func {
1686                                mem::take(inner_operands)
1687                            } else {
1688                                vec![o.take()]
1689                            }
1690                        } else {
1691                            vec![o.take()]
1692                        }
1693                    })
1694                    .collect();
1695            }
1696            _ => {}
1697        }
1698    }
1699
1700    /* #region AND/OR canonicalization and transformations  */
1701
1702    /// Canonicalizes AND/OR, and does some straightforward simplifications
1703    fn reduce_and_canonicalize_and_or(&mut self) {
1704        // We do this until fixed point, because after undistribute_and_or calls us, it relies on
1705        // the property that self is not an 1-arg AND/OR. Just one application of our loop body
1706        // can't ensure this, because the application itself might create a 1-arg AND/OR.
1707        let mut old_self = MirScalarExpr::column(0);
1708        while old_self != *self {
1709            old_self = self.clone();
1710            match self {
1711                MirScalarExpr::CallVariadic {
1712                    func: func @ (VariadicFunc::And | VariadicFunc::Or),
1713                    exprs,
1714                } => {
1715                    // Canonically order elements so that various deduplications work better,
1716                    // e.g., in undistribute_and_or.
1717                    // Also, extract_equal_or_both_null_inner depends on the args being sorted.
1718                    exprs.sort();
1719
1720                    // x AND/OR x --> x
1721                    exprs.dedup(); // this also needs the above sorting
1722
1723                    if exprs.len() == 1 {
1724                        // AND/OR of 1 argument evaluates to that argument
1725                        *self = exprs.swap_remove(0);
1726                    } else if exprs.len() == 0 {
1727                        // AND/OR of 0 arguments evaluates to true/false
1728                        *self = func.unit_of_and_or();
1729                    } else if exprs.iter().any(|e| *e == func.zero_of_and_or()) {
1730                        // short-circuiting
1731                        *self = func.zero_of_and_or();
1732                    } else {
1733                        // a AND true --> a
1734                        // a OR false --> a
1735                        exprs.retain(|e| *e != func.unit_of_and_or());
1736                    }
1737                }
1738                _ => {}
1739            }
1740        }
1741    }
1742
1743    /// Transforms !(a && b) into !a || !b, and !(a || b) into !a && !b
1744    fn demorgans(&mut self) {
1745        if let MirScalarExpr::CallUnary {
1746            expr: inner,
1747            func: UnaryFunc::Not(func::Not),
1748        } = self
1749        {
1750            inner.flatten_associative();
1751            match &mut **inner {
1752                MirScalarExpr::CallVariadic {
1753                    func: inner_func @ (VariadicFunc::And | VariadicFunc::Or),
1754                    exprs,
1755                } => {
1756                    *inner_func = inner_func.switch_and_or();
1757                    *exprs = exprs.into_iter().map(|e| e.take().not()).collect();
1758                    *self = (*inner).take(); // Removes the outer not
1759                }
1760                _ => {}
1761            }
1762        }
1763    }
1764
1765    /// AND/OR undistribution (factoring out) to apply at each `MirScalarExpr`.
1766    ///
1767    /// This method attempts to apply one of the [distribution laws][distributivity]
1768    /// (in a direction opposite to their name):
1769    /// ```text
1770    /// (a && b) || (a && c) --> a && (b || c)  // Undistribute-OR
1771    /// (a || b) && (a || c) --> a || (b && c)  // Undistribute-AND
1772    /// ```
1773    /// or one of their corresponding two [absorption law][absorption] special
1774    /// cases:
1775    /// ```text
1776    /// a || (a && c)  -->  a  // Absorb-OR
1777    /// a && (a || c)  -->  a  // Absorb-AND
1778    /// ```
1779    ///
1780    /// The method also works with more than 2 arguments at the top, e.g.
1781    /// ```text
1782    /// (a && b) || (a && c) || (a && d)  -->  a && (b || c || d)
1783    /// ```
1784    /// It can also factor out only a subset of the top arguments, e.g.
1785    /// ```text
1786    /// (a && b) || (a && c) || (d && e)  -->  (a && (b || c)) || (d && e)
1787    /// ```
1788    ///
1789    /// Note that sometimes there are two overlapping possibilities to factor
1790    /// out from, e.g.
1791    /// ```text
1792    /// (a && b) || (a && c) || (d && c)
1793    /// ```
1794    /// Here we can factor out `a` from the 1. and 2. terms, or we can
1795    /// factor out `c` from the 2. and 3. terms. One of these might lead to
1796    /// more/better undistribution opportunities later, but we just pick one
1797    /// locally, because recursively trying out all of them would lead to
1798    /// exponential run time.
1799    ///
1800    /// The local heuristic is that we prefer a candidate that leads to an
1801    /// absorption, or if there is no such one then we simply pick the first. In
1802    /// case of multiple absorption candidates, it doesn't matter which one we
1803    /// pick, because applying an absorption cannot adversely affect the
1804    /// possibility of applying other absorptions.
1805    ///
1806    /// # Assumption
1807    ///
1808    /// Assumes that nested chains of AND/OR applications are flattened (this
1809    /// can be enforced with [`Self::flatten_associative`]).
1810    ///
1811    /// # Examples
1812    ///
1813    /// Absorb-OR:
1814    /// ```text
1815    /// a || (a && c) || (a && d)
1816    /// -->
1817    /// a && (true || c || d)
1818    /// -->
1819    /// a && true
1820    /// -->
1821    /// a
1822    /// ```
1823    /// Here only the first step is performed by this method. The rest is done
1824    /// by [`Self::reduce_and_canonicalize_and_or`] called after us in
1825    /// `reduce()`.
1826    ///
1827    /// [distributivity]: https://en.wikipedia.org/wiki/Distributive_property
1828    /// [absorption]: https://en.wikipedia.org/wiki/Absorption_law
1829    fn undistribute_and_or(&mut self) {
1830        // It wouldn't be strictly necessary to wrap this fn in this loop, because `reduce()` calls
1831        // us in a loop anyway. However, `reduce()` tries to do many other things, so the loop here
1832        // improves performance when there are several undistributions to apply in sequence, which
1833        // can occur in `CanonicalizeMfp` when undoing the DNF.
1834        let mut old_self = MirScalarExpr::column(0);
1835        while old_self != *self {
1836            old_self = self.clone();
1837            self.reduce_and_canonicalize_and_or(); // We don't want to deal with 1-arg AND/OR at the top
1838            if let MirScalarExpr::CallVariadic {
1839                exprs: outer_operands,
1840                func: outer_func @ (VariadicFunc::Or | VariadicFunc::And),
1841            } = self
1842            {
1843                let inner_func = outer_func.switch_and_or();
1844
1845                // Make sure that each outer operand is a call to inner_func, by wrapping in a 1-arg
1846                // call if necessary.
1847                outer_operands.iter_mut().for_each(|o| {
1848                    if !matches!(o, MirScalarExpr::CallVariadic {func: f, ..} if *f == inner_func) {
1849                        *o = MirScalarExpr::CallVariadic {
1850                            func: inner_func.clone(),
1851                            exprs: vec![o.take()],
1852                        };
1853                    }
1854                });
1855
1856                let mut inner_operands_refs: Vec<&mut Vec<MirScalarExpr>> = outer_operands
1857                    .iter_mut()
1858                    .map(|o| match o {
1859                        MirScalarExpr::CallVariadic { func: f, exprs } if *f == inner_func => exprs,
1860                        _ => unreachable!(), // the wrapping made sure that we'll get a match
1861                    })
1862                    .collect();
1863
1864                // Find inner operands to undistribute, i.e., which are in _all_ of the outer operands.
1865                let mut intersection = inner_operands_refs
1866                    .iter()
1867                    .map(|v| (*v).clone())
1868                    .reduce(|ops1, ops2| ops1.into_iter().filter(|e| ops2.contains(e)).collect())
1869                    .unwrap();
1870                intersection.sort();
1871                intersection.dedup();
1872
1873                if !intersection.is_empty() {
1874                    // Factor out the intersection from all the top-level args.
1875
1876                    // Remove the intersection from each inner operand vector.
1877                    inner_operands_refs
1878                        .iter_mut()
1879                        .for_each(|ops| (**ops).retain(|o| !intersection.contains(o)));
1880
1881                    // Simplify terms that now have only 0 or 1 args due to removing the intersection.
1882                    outer_operands
1883                        .iter_mut()
1884                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1885
1886                    // Add the intersection at the beginning
1887                    *self = MirScalarExpr::CallVariadic {
1888                        func: inner_func,
1889                        exprs: intersection.into_iter().chain_one(self.clone()).collect(),
1890                    };
1891                } else {
1892                    // If the intersection was empty, that means that there is nothing we can factor out
1893                    // from _all_ the top-level args. However, we might still find something to factor
1894                    // out from a subset of the top-level args. To find such an opportunity, we look for
1895                    // duplicates across all inner args, e.g. if we have
1896                    // `(...) OR (... AND `a` AND ...) OR (...) OR (... AND `a` AND ...)`
1897                    // then we'll find that `a` occurs in more than one top-level arg, so
1898                    // `indexes_to_undistribute` will point us to the 2. and 4. top-level args.
1899
1900                    // Create (inner_operand, index) pairs, where the index is the position in
1901                    // outer_operands
1902                    let all_inner_operands = inner_operands_refs
1903                        .iter()
1904                        .enumerate()
1905                        .flat_map(|(i, inner_vec)| inner_vec.iter().map(move |a| ((*a).clone(), i)))
1906                        .sorted()
1907                        .collect_vec();
1908
1909                    // Find inner operand expressions that occur in more than one top-level arg.
1910                    // Each inner vector in `undistribution_opportunities` will belong to one such inner
1911                    // operand expression, and it is a set of indexes pointing to top-level args where
1912                    // that inner operand occurs.
1913                    let undistribution_opportunities = all_inner_operands
1914                        .iter()
1915                        .chunk_by(|(a, _i)| a)
1916                        .into_iter()
1917                        .map(|(_a, g)| g.map(|(_a, i)| *i).sorted().dedup().collect_vec())
1918                        .filter(|g| g.len() > 1)
1919                        .collect_vec();
1920
1921                    // Choose one of the inner vectors from `undistribution_opportunities`.
1922                    let indexes_to_undistribute = undistribution_opportunities
1923                        .iter()
1924                        // Let's prefer index sets that directly lead to an absorption.
1925                        .find(|index_set| {
1926                            index_set
1927                                .iter()
1928                                .any(|i| inner_operands_refs.get(*i).unwrap().len() == 1)
1929                        })
1930                        // If we didn't find any absorption, then any index set will do.
1931                        .or_else(|| undistribution_opportunities.first())
1932                        .cloned();
1933
1934                    // In any case, undo the 1-arg wrapping that we did at the beginning.
1935                    outer_operands
1936                        .iter_mut()
1937                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1938
1939                    if let Some(indexes_to_undistribute) = indexes_to_undistribute {
1940                        // Found something to undistribute from a subset of the outer operands.
1941                        // We temporarily remove these from outer_operands, call ourselves on it, and
1942                        // then push back the result.
1943                        let mut undistribute_from = MirScalarExpr::CallVariadic {
1944                            func: outer_func.clone(),
1945                            exprs: swap_remove_multiple(outer_operands, indexes_to_undistribute),
1946                        };
1947                        // By construction, the recursive call is guaranteed to hit
1948                        // the `!intersection.is_empty()` branch.
1949                        undistribute_from.undistribute_and_or();
1950                        // Append the undistributed result to outer operands that were not included in
1951                        // indexes_to_undistribute.
1952                        outer_operands.push(undistribute_from);
1953                    }
1954                }
1955            }
1956        }
1957    }
1958
1959    /* #endregion */
1960
1961    /// Adds any columns that *must* be non-Null for `self` to be non-Null.
1962    pub fn non_null_requirements(&self, columns: &mut BTreeSet<usize>) {
1963        match self {
1964            MirScalarExpr::Column(col, _name) => {
1965                columns.insert(*col);
1966            }
1967            MirScalarExpr::Literal(..) => {}
1968            MirScalarExpr::CallUnmaterializable(_) => (),
1969            MirScalarExpr::CallUnary { func, expr } => {
1970                if func.propagates_nulls() {
1971                    expr.non_null_requirements(columns);
1972                }
1973            }
1974            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1975                if func.propagates_nulls() {
1976                    expr1.non_null_requirements(columns);
1977                    expr2.non_null_requirements(columns);
1978                }
1979            }
1980            MirScalarExpr::CallVariadic { func, exprs } => {
1981                if func.propagates_nulls() {
1982                    for expr in exprs {
1983                        expr.non_null_requirements(columns);
1984                    }
1985                }
1986            }
1987            MirScalarExpr::If { .. } => (),
1988        }
1989    }
1990
1991    pub fn typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
1992        match self {
1993            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
1994            MirScalarExpr::Literal(_, typ) => typ.clone(),
1995            MirScalarExpr::CallUnmaterializable(func) => func.output_type(),
1996            MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)),
1997            MirScalarExpr::CallBinary { expr1, expr2, func } => {
1998                func.output_type(&[expr1.typ(column_types), expr2.typ(column_types)])
1999            }
2000            MirScalarExpr::CallVariadic { exprs, func } => {
2001                func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect())
2002            }
2003            MirScalarExpr::If { cond: _, then, els } => {
2004                let then_type = then.typ(column_types);
2005                let else_type = els.typ(column_types);
2006                then_type.union(&else_type)
2007            }
2008        }
2009    }
2010
2011    pub fn repr_typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2012        match self {
2013            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
2014            MirScalarExpr::Literal(_, typ) => ReprColumnType::from(typ),
2015            MirScalarExpr::CallUnmaterializable(func) => ReprColumnType::from(&func.output_type()),
2016            MirScalarExpr::CallUnary { expr, func } => ReprColumnType::from(
2017                &func.output_type(SqlColumnType::from_repr(&expr.repr_typ(column_types))),
2018            ),
2019            MirScalarExpr::CallBinary { expr1, expr2, func } => {
2020                ReprColumnType::from(&func.output_type(&[
2021                    SqlColumnType::from_repr(&expr1.repr_typ(column_types)),
2022                    SqlColumnType::from_repr(&expr2.repr_typ(column_types)),
2023                ]))
2024            }
2025            MirScalarExpr::CallVariadic { exprs, func } => ReprColumnType::from(
2026                &func.output_type(
2027                    exprs
2028                        .iter()
2029                        .map(|e| SqlColumnType::from_repr(&e.repr_typ(column_types)))
2030                        .collect(),
2031                ),
2032            ),
2033            MirScalarExpr::If { cond: _, then, els } => {
2034                let then_type = then.repr_typ(column_types);
2035                let else_type = els.repr_typ(column_types);
2036                then_type.union(&else_type).unwrap()
2037            }
2038        }
2039    }
2040
2041    pub fn eval<'a>(
2042        &'a self,
2043        datums: &[Datum<'a>],
2044        temp_storage: &'a RowArena,
2045    ) -> Result<Datum<'a>, EvalError> {
2046        match self {
2047            MirScalarExpr::Column(index, _name) => Ok(datums[*index]),
2048            MirScalarExpr::Literal(res, _column_type) => match res {
2049                Ok(row) => Ok(row.unpack_first()),
2050                Err(e) => Err(e.clone()),
2051            },
2052            // Unmaterializable functions must be transformed away before
2053            // evaluation. Their purpose is as a placeholder for data that is
2054            // not known at plan time but can be inlined before runtime.
2055            MirScalarExpr::CallUnmaterializable(x) => Err(EvalError::Internal(
2056                format!("cannot evaluate unmaterializable function: {:?}", x).into(),
2057            )),
2058            MirScalarExpr::CallUnary { func, expr } => func.eval(datums, temp_storage, expr),
2059            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2060                func.eval(datums, temp_storage, &[expr1, expr2])
2061            }
2062            MirScalarExpr::CallVariadic { func, exprs } => func.eval(datums, temp_storage, exprs),
2063            MirScalarExpr::If { cond, then, els } => match cond.eval(datums, temp_storage)? {
2064                Datum::True => then.eval(datums, temp_storage),
2065                Datum::False | Datum::Null => els.eval(datums, temp_storage),
2066                d => Err(EvalError::Internal(
2067                    format!("if condition evaluated to non-boolean datum: {:?}", d).into(),
2068                )),
2069            },
2070        }
2071    }
2072
2073    /// True iff the expression contains
2074    /// `UnmaterializableFunc::MzNow`.
2075    pub fn contains_temporal(&self) -> bool {
2076        let mut contains = false;
2077        self.visit_pre(|e| {
2078            if let MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2079                contains = true;
2080            }
2081        });
2082        contains
2083    }
2084
2085    /// True iff the expression contains an `UnmaterializableFunc`.
2086    pub fn contains_unmaterializable(&self) -> bool {
2087        let mut contains = false;
2088        self.visit_pre(|e| {
2089            if let MirScalarExpr::CallUnmaterializable(_) = e {
2090                contains = true;
2091            }
2092        });
2093        contains
2094    }
2095
2096    /// True iff the expression contains an `UnmaterializableFunc` that is not in the `exceptions`
2097    /// list.
2098    pub fn contains_unmaterializable_except(&self, exceptions: &[UnmaterializableFunc]) -> bool {
2099        let mut contains = false;
2100        self.visit_pre(|e| match e {
2101            MirScalarExpr::CallUnmaterializable(f) if !exceptions.contains(f) => contains = true,
2102            _ => (),
2103        });
2104        contains
2105    }
2106
2107    /// True iff the expression contains a `Column`.
2108    pub fn contains_column(&self) -> bool {
2109        let mut contains = false;
2110        self.visit_pre(|e| {
2111            if let MirScalarExpr::Column(_col, _name) = e {
2112                contains = true;
2113            }
2114        });
2115        contains
2116    }
2117
2118    /// True iff the expression contains a `Dummy`.
2119    pub fn contains_dummy(&self) -> bool {
2120        let mut contains = false;
2121        self.visit_pre(|e| {
2122            if let MirScalarExpr::Literal(row, _) = e {
2123                if let Ok(row) = row {
2124                    contains |= row.iter().any(|d| d.contains_dummy());
2125                }
2126            }
2127        });
2128        contains
2129    }
2130
2131    /// The size of the expression as a tree.
2132    pub fn size(&self) -> usize {
2133        let mut size = 0;
2134        self.visit_pre(&mut |_: &MirScalarExpr| {
2135            size += 1;
2136        });
2137        size
2138    }
2139}
2140
2141impl MirScalarExpr {
2142    /// True iff evaluation could possibly error on non-error input `Datum`.
2143    pub fn could_error(&self) -> bool {
2144        match self {
2145            MirScalarExpr::Column(_col, _name) => false,
2146            MirScalarExpr::Literal(row, ..) => row.is_err(),
2147            MirScalarExpr::CallUnmaterializable(_) => true,
2148            MirScalarExpr::CallUnary { func, expr } => func.could_error() || expr.could_error(),
2149            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2150                func.could_error() || expr1.could_error() || expr2.could_error()
2151            }
2152            MirScalarExpr::CallVariadic { func, exprs } => {
2153                func.could_error() || exprs.iter().any(|e| e.could_error())
2154            }
2155            MirScalarExpr::If { cond, then, els } => {
2156                cond.could_error() || then.could_error() || els.could_error()
2157            }
2158        }
2159    }
2160}
2161
2162impl VisitChildren<Self> for MirScalarExpr {
2163    fn visit_children<F>(&self, mut f: F)
2164    where
2165        F: FnMut(&Self),
2166    {
2167        use MirScalarExpr::*;
2168        match self {
2169            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2170            CallUnary { expr, .. } => {
2171                f(expr);
2172            }
2173            CallBinary { expr1, expr2, .. } => {
2174                f(expr1);
2175                f(expr2);
2176            }
2177            CallVariadic { exprs, .. } => {
2178                for expr in exprs {
2179                    f(expr);
2180                }
2181            }
2182            If { cond, then, els } => {
2183                f(cond);
2184                f(then);
2185                f(els);
2186            }
2187        }
2188    }
2189
2190    fn visit_mut_children<F>(&mut self, mut f: F)
2191    where
2192        F: FnMut(&mut Self),
2193    {
2194        use MirScalarExpr::*;
2195        match self {
2196            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2197            CallUnary { expr, .. } => {
2198                f(expr);
2199            }
2200            CallBinary { expr1, expr2, .. } => {
2201                f(expr1);
2202                f(expr2);
2203            }
2204            CallVariadic { exprs, .. } => {
2205                for expr in exprs {
2206                    f(expr);
2207                }
2208            }
2209            If { cond, then, els } => {
2210                f(cond);
2211                f(then);
2212                f(els);
2213            }
2214        }
2215    }
2216
2217    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2218    where
2219        F: FnMut(&Self) -> Result<(), E>,
2220        E: From<RecursionLimitError>,
2221    {
2222        use MirScalarExpr::*;
2223        match self {
2224            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2225            CallUnary { expr, .. } => {
2226                f(expr)?;
2227            }
2228            CallBinary { expr1, expr2, .. } => {
2229                f(expr1)?;
2230                f(expr2)?;
2231            }
2232            CallVariadic { exprs, .. } => {
2233                for expr in exprs {
2234                    f(expr)?;
2235                }
2236            }
2237            If { cond, then, els } => {
2238                f(cond)?;
2239                f(then)?;
2240                f(els)?;
2241            }
2242        }
2243        Ok(())
2244    }
2245
2246    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2247    where
2248        F: FnMut(&mut Self) -> Result<(), E>,
2249        E: From<RecursionLimitError>,
2250    {
2251        use MirScalarExpr::*;
2252        match self {
2253            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2254            CallUnary { expr, .. } => {
2255                f(expr)?;
2256            }
2257            CallBinary { expr1, expr2, .. } => {
2258                f(expr1)?;
2259                f(expr2)?;
2260            }
2261            CallVariadic { exprs, .. } => {
2262                for expr in exprs {
2263                    f(expr)?;
2264                }
2265            }
2266            If { cond, then, els } => {
2267                f(cond)?;
2268                f(then)?;
2269                f(els)?;
2270            }
2271        }
2272        Ok(())
2273    }
2274}
2275
2276impl MirScalarExpr {
2277    /// Iterates through references to child expressions.
2278    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2279        let mut first = None;
2280        let mut second = None;
2281        let mut third = None;
2282        let mut variadic = None;
2283
2284        use MirScalarExpr::*;
2285        match self {
2286            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2287            CallUnary { expr, .. } => {
2288                first = Some(&**expr);
2289            }
2290            CallBinary { expr1, expr2, .. } => {
2291                first = Some(&**expr1);
2292                second = Some(&**expr2);
2293            }
2294            CallVariadic { exprs, .. } => {
2295                variadic = Some(exprs);
2296            }
2297            If { cond, then, els } => {
2298                first = Some(&**cond);
2299                second = Some(&**then);
2300                third = Some(&**els);
2301            }
2302        }
2303
2304        first
2305            .into_iter()
2306            .chain(second)
2307            .chain(third)
2308            .chain(variadic.into_iter().flatten())
2309    }
2310
2311    /// Iterates through mutable references to child expressions.
2312    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2313        let mut first = None;
2314        let mut second = None;
2315        let mut third = None;
2316        let mut variadic = None;
2317
2318        use MirScalarExpr::*;
2319        match self {
2320            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2321            CallUnary { expr, .. } => {
2322                first = Some(&mut **expr);
2323            }
2324            CallBinary { expr1, expr2, .. } => {
2325                first = Some(&mut **expr1);
2326                second = Some(&mut **expr2);
2327            }
2328            CallVariadic { exprs, .. } => {
2329                variadic = Some(exprs);
2330            }
2331            If { cond, then, els } => {
2332                first = Some(&mut **cond);
2333                second = Some(&mut **then);
2334                third = Some(&mut **els);
2335            }
2336        }
2337
2338        first
2339            .into_iter()
2340            .chain(second)
2341            .chain(third)
2342            .chain(variadic.into_iter().flatten())
2343    }
2344
2345    /// Visits all subexpressions in DFS preorder.
2346    pub fn visit_pre<F>(&self, mut f: F)
2347    where
2348        F: FnMut(&Self),
2349    {
2350        let mut worklist = vec![self];
2351        while let Some(e) = worklist.pop() {
2352            f(e);
2353            worklist.extend(e.children().rev());
2354        }
2355    }
2356
2357    /// Iterative pre-order visitor.
2358    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2359        let mut worklist = vec![self];
2360        while let Some(expr) = worklist.pop() {
2361            f(expr);
2362            worklist.extend(expr.children_mut().rev());
2363        }
2364    }
2365}
2366
2367/// Filter characteristics that are used for ordering join inputs.
2368/// This can be created for a `Vec<MirScalarExpr>`, which represents an AND of predicates.
2369///
2370/// The fields are ordered based on heuristic assumptions about their typical selectivity, so that
2371/// Ord gives the right ordering for join inputs. Bigger is better, i.e., will tend to come earlier
2372/// than other inputs.
2373#[derive(
2374    Eq,
2375    PartialEq,
2376    Ord,
2377    PartialOrd,
2378    Debug,
2379    Clone,
2380    Serialize,
2381    Deserialize,
2382    Hash,
2383    MzReflect
2384)]
2385pub struct FilterCharacteristics {
2386    // `<expr> = <literal>` appears in the filter.
2387    // Excludes cases where NOT appears anywhere above the literal equality.
2388    literal_equality: bool,
2389    // (Assuming a random string of lower-case characters, `LIKE 'a%'` has a selectivity of 1/26.)
2390    like: bool,
2391    is_null: bool,
2392    // Number of Vec elements that involve inequality predicates. (A BETWEEN is represented as two
2393    // inequality predicates.)
2394    // Excludes cases where NOT appears around the literal inequality.
2395    // Note that for inequality predicates, some databases assume 1/3 selectivity in the absence of
2396    // concrete statistics.
2397    literal_inequality: usize,
2398    /// Any filter, except ones involving `IS NOT NULL`, because those are too common.
2399    /// Can be true by itself, or any other field being true can also make this true.
2400    /// `NOT LIKE` is only in this category.
2401    /// `!=` is only in this category.
2402    /// `NOT (a = b)` is turned into `!=` by `reduce` before us!
2403    any_filter: bool,
2404}
2405
2406impl BitOrAssign for FilterCharacteristics {
2407    fn bitor_assign(&mut self, rhs: Self) {
2408        self.literal_equality |= rhs.literal_equality;
2409        self.like |= rhs.like;
2410        self.is_null |= rhs.is_null;
2411        self.literal_inequality += rhs.literal_inequality;
2412        self.any_filter |= rhs.any_filter;
2413    }
2414}
2415
2416impl FilterCharacteristics {
2417    pub fn none() -> FilterCharacteristics {
2418        FilterCharacteristics {
2419            literal_equality: false,
2420            like: false,
2421            is_null: false,
2422            literal_inequality: 0,
2423            any_filter: false,
2424        }
2425    }
2426
2427    pub fn explain(&self) -> String {
2428        let mut e = "".to_owned();
2429        if self.literal_equality {
2430            e.push_str("e");
2431        }
2432        if self.like {
2433            e.push_str("l");
2434        }
2435        if self.is_null {
2436            e.push_str("n");
2437        }
2438        for _ in 0..self.literal_inequality {
2439            e.push_str("i");
2440        }
2441        if self.any_filter {
2442            e.push_str("f");
2443        }
2444        e
2445    }
2446
2447    pub fn filter_characteristics(
2448        filters: &Vec<MirScalarExpr>,
2449    ) -> Result<FilterCharacteristics, RecursionLimitError> {
2450        let mut literal_equality = false;
2451        let mut like = false;
2452        let mut is_null = false;
2453        let mut literal_inequality = 0;
2454        let mut any_filter = false;
2455        filters.iter().try_for_each(|f| {
2456            let mut literal_inequality_in_current_filter = false;
2457            let mut is_not_null_in_current_filter = false;
2458            f.visit_pre_with_context(
2459                false,
2460                &mut |not_in_parent_chain, expr| {
2461                    not_in_parent_chain
2462                        || matches!(
2463                            expr,
2464                            MirScalarExpr::CallUnary {
2465                                func: UnaryFunc::Not(func::Not),
2466                                ..
2467                            }
2468                        )
2469                },
2470                &mut |not_in_parent_chain, expr| {
2471                    if !not_in_parent_chain {
2472                        if expr.any_expr_eq_literal().is_some() {
2473                            literal_equality = true;
2474                        }
2475                        if expr.any_expr_ineq_literal() {
2476                            literal_inequality_in_current_filter = true;
2477                        }
2478                        if matches!(
2479                            expr,
2480                            MirScalarExpr::CallUnary {
2481                                func: UnaryFunc::IsLikeMatch(_),
2482                                ..
2483                            }
2484                        ) {
2485                            like = true;
2486                        }
2487                    };
2488                    if matches!(
2489                        expr,
2490                        MirScalarExpr::CallUnary {
2491                            func: UnaryFunc::IsNull(crate::func::IsNull),
2492                            ..
2493                        }
2494                    ) {
2495                        if *not_in_parent_chain {
2496                            is_not_null_in_current_filter = true;
2497                        } else {
2498                            is_null = true;
2499                        }
2500                    }
2501                },
2502            )?;
2503            if literal_inequality_in_current_filter {
2504                literal_inequality += 1;
2505            }
2506            if !is_not_null_in_current_filter {
2507                // We want to ignore `IS NOT NULL` for `any_filter`.
2508                any_filter = true;
2509            }
2510            Ok(())
2511        })?;
2512        Ok(FilterCharacteristics {
2513            literal_equality,
2514            like,
2515            is_null,
2516            literal_inequality,
2517            any_filter,
2518        })
2519    }
2520
2521    pub fn add_literal_equality(&mut self) {
2522        self.literal_equality = true;
2523    }
2524
2525    pub fn worst_case_scaling_factor(&self) -> f64 {
2526        let mut factor = 1.0;
2527
2528        if self.literal_equality {
2529            factor *= 0.1;
2530        }
2531
2532        if self.is_null {
2533            factor *= 0.1;
2534        }
2535
2536        if self.literal_inequality >= 2 {
2537            factor *= 0.25;
2538        } else if self.literal_inequality == 1 {
2539            factor *= 0.33;
2540        }
2541
2542        // catch various negated filters, treat them pessimistically
2543        if !(self.literal_equality || self.is_null || self.literal_inequality > 0)
2544            && self.any_filter
2545        {
2546            factor *= 0.9;
2547        }
2548
2549        factor
2550    }
2551}
2552
2553#[derive(
2554    Arbitrary,
2555    Ord,
2556    PartialOrd,
2557    Copy,
2558    Clone,
2559    Debug,
2560    Eq,
2561    PartialEq,
2562    Serialize,
2563    Deserialize,
2564    Hash,
2565    MzReflect
2566)]
2567pub enum DomainLimit {
2568    None,
2569    Inclusive(i64),
2570    Exclusive(i64),
2571}
2572
2573impl RustType<ProtoDomainLimit> for DomainLimit {
2574    fn into_proto(&self) -> ProtoDomainLimit {
2575        use proto_domain_limit::Kind::*;
2576        let kind = match self {
2577            DomainLimit::None => None(()),
2578            DomainLimit::Inclusive(v) => Inclusive(*v),
2579            DomainLimit::Exclusive(v) => Exclusive(*v),
2580        };
2581        ProtoDomainLimit { kind: Some(kind) }
2582    }
2583
2584    fn from_proto(proto: ProtoDomainLimit) -> Result<Self, TryFromProtoError> {
2585        use proto_domain_limit::Kind::*;
2586        if let Some(kind) = proto.kind {
2587            match kind {
2588                None(()) => Ok(DomainLimit::None),
2589                Inclusive(v) => Ok(DomainLimit::Inclusive(v)),
2590                Exclusive(v) => Ok(DomainLimit::Exclusive(v)),
2591            }
2592        } else {
2593            Err(TryFromProtoError::missing_field("ProtoDomainLimit::kind"))
2594        }
2595    }
2596}
2597
2598#[derive(
2599    Arbitrary,
2600    Ord,
2601    PartialOrd,
2602    Clone,
2603    Debug,
2604    Eq,
2605    PartialEq,
2606    Serialize,
2607    Deserialize,
2608    Hash,
2609    MzReflect
2610)]
2611pub enum EvalError {
2612    CharacterNotValidForEncoding(i32),
2613    CharacterTooLargeForEncoding(i32),
2614    DateBinOutOfRange(Box<str>),
2615    DivisionByZero,
2616    Unsupported {
2617        feature: Box<str>,
2618        discussion_no: Option<usize>,
2619    },
2620    FloatOverflow,
2621    FloatUnderflow,
2622    NumericFieldOverflow,
2623    Float32OutOfRange(Box<str>),
2624    Float64OutOfRange(Box<str>),
2625    Int16OutOfRange(Box<str>),
2626    Int32OutOfRange(Box<str>),
2627    Int64OutOfRange(Box<str>),
2628    UInt16OutOfRange(Box<str>),
2629    UInt32OutOfRange(Box<str>),
2630    UInt64OutOfRange(Box<str>),
2631    MzTimestampOutOfRange(Box<str>),
2632    MzTimestampStepOverflow,
2633    OidOutOfRange(Box<str>),
2634    IntervalOutOfRange(Box<str>),
2635    TimestampCannotBeNan,
2636    TimestampOutOfRange,
2637    DateOutOfRange,
2638    CharOutOfRange,
2639    IndexOutOfRange {
2640        provided: i32,
2641        // The last valid index position, i.e. `v.len() - 1`
2642        valid_end: i32,
2643    },
2644    InvalidBase64Equals,
2645    InvalidBase64Symbol(char),
2646    InvalidBase64EndSequence,
2647    InvalidTimezone(Box<str>),
2648    InvalidTimezoneInterval,
2649    InvalidTimezoneConversion,
2650    InvalidIanaTimezoneId(Box<str>),
2651    InvalidLayer {
2652        max_layer: usize,
2653        val: i64,
2654    },
2655    InvalidArray(InvalidArrayError),
2656    InvalidEncodingName(Box<str>),
2657    InvalidHashAlgorithm(Box<str>),
2658    InvalidByteSequence {
2659        byte_sequence: Box<str>,
2660        encoding_name: Box<str>,
2661    },
2662    InvalidJsonbCast {
2663        from: Box<str>,
2664        to: Box<str>,
2665    },
2666    InvalidRegex(Box<str>),
2667    InvalidRegexFlag(char),
2668    InvalidParameterValue(Box<str>),
2669    InvalidDatePart(Box<str>),
2670    KeyCannotBeNull,
2671    NegSqrt,
2672    NegLimit,
2673    NullCharacterNotPermitted,
2674    UnknownUnits(Box<str>),
2675    UnsupportedUnits(Box<str>, Box<str>),
2676    UnterminatedLikeEscapeSequence,
2677    Parse(ParseError),
2678    ParseHex(ParseHexError),
2679    Internal(Box<str>),
2680    InfinityOutOfDomain(Box<str>),
2681    NegativeOutOfDomain(Box<str>),
2682    ZeroOutOfDomain(Box<str>),
2683    OutOfDomain(DomainLimit, DomainLimit, Box<str>),
2684    ComplexOutOfRange(Box<str>),
2685    MultipleRowsFromSubquery,
2686    Undefined(Box<str>),
2687    LikePatternTooLong,
2688    LikeEscapeTooLong,
2689    StringValueTooLong {
2690        target_type: Box<str>,
2691        length: usize,
2692    },
2693    MultidimensionalArrayRemovalNotSupported,
2694    IncompatibleArrayDimensions {
2695        dims: Option<(usize, usize)>,
2696    },
2697    TypeFromOid(Box<str>),
2698    InvalidRange(InvalidRangeError),
2699    InvalidRoleId(Box<str>),
2700    InvalidPrivileges(Box<str>),
2701    LetRecLimitExceeded(Box<str>),
2702    MultiDimensionalArraySearch,
2703    MustNotBeNull(Box<str>),
2704    InvalidIdentifier {
2705        ident: Box<str>,
2706        detail: Option<Box<str>>,
2707    },
2708    ArrayFillWrongArraySubscripts,
2709    // TODO: propagate this check more widely throughout the expr crate
2710    MaxArraySizeExceeded(usize),
2711    DateDiffOverflow {
2712        unit: Box<str>,
2713        a: Box<str>,
2714        b: Box<str>,
2715    },
2716    // The error for ErrorIfNull; this should not be used in other contexts as a generic error
2717    // printer.
2718    IfNullError(Box<str>),
2719    LengthTooLarge,
2720    AclArrayNullElement,
2721    MzAclArrayNullElement,
2722    PrettyError(Box<str>),
2723}
2724
2725impl fmt::Display for EvalError {
2726    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2727        match self {
2728            EvalError::CharacterNotValidForEncoding(v) => {
2729                write!(f, "requested character not valid for encoding: {v}")
2730            }
2731            EvalError::CharacterTooLargeForEncoding(v) => {
2732                write!(f, "requested character too large for encoding: {v}")
2733            }
2734            EvalError::DateBinOutOfRange(message) => f.write_str(message),
2735            EvalError::DivisionByZero => f.write_str("division by zero"),
2736            EvalError::Unsupported {
2737                feature,
2738                discussion_no,
2739            } => {
2740                write!(f, "{} not yet supported", feature)?;
2741                if let Some(discussion_no) = discussion_no {
2742                    write!(
2743                        f,
2744                        ", see https://github.com/MaterializeInc/materialize/discussions/{} for more details",
2745                        discussion_no
2746                    )?;
2747                }
2748                Ok(())
2749            }
2750            EvalError::FloatOverflow => f.write_str("value out of range: overflow"),
2751            EvalError::FloatUnderflow => f.write_str("value out of range: underflow"),
2752            EvalError::NumericFieldOverflow => f.write_str("numeric field overflow"),
2753            EvalError::Float32OutOfRange(val) => write!(f, "{} real out of range", val.quoted()),
2754            EvalError::Float64OutOfRange(val) => {
2755                write!(f, "{} double precision out of range", val.quoted())
2756            }
2757            EvalError::Int16OutOfRange(val) => write!(f, "{} smallint out of range", val.quoted()),
2758            EvalError::Int32OutOfRange(val) => write!(f, "{} integer out of range", val.quoted()),
2759            EvalError::Int64OutOfRange(val) => write!(f, "{} bigint out of range", val.quoted()),
2760            EvalError::UInt16OutOfRange(val) => write!(f, "{} uint2 out of range", val.quoted()),
2761            EvalError::UInt32OutOfRange(val) => write!(f, "{} uint4 out of range", val.quoted()),
2762            EvalError::UInt64OutOfRange(val) => write!(f, "{} uint8 out of range", val.quoted()),
2763            EvalError::MzTimestampOutOfRange(val) => {
2764                write!(f, "{} mz_timestamp out of range", val.quoted())
2765            }
2766            EvalError::MzTimestampStepOverflow => f.write_str("step mz_timestamp overflow"),
2767            EvalError::OidOutOfRange(val) => write!(f, "{} OID out of range", val.quoted()),
2768            EvalError::IntervalOutOfRange(val) => {
2769                write!(f, "{} interval out of range", val.quoted())
2770            }
2771            EvalError::TimestampCannotBeNan => f.write_str("timestamp cannot be NaN"),
2772            EvalError::TimestampOutOfRange => f.write_str("timestamp out of range"),
2773            EvalError::DateOutOfRange => f.write_str("date out of range"),
2774            EvalError::CharOutOfRange => f.write_str("\"char\" out of range"),
2775            EvalError::IndexOutOfRange {
2776                provided,
2777                valid_end,
2778            } => write!(f, "index {provided} out of valid range, 0..{valid_end}",),
2779            EvalError::InvalidBase64Equals => {
2780                f.write_str("unexpected \"=\" while decoding base64 sequence")
2781            }
2782            EvalError::InvalidBase64Symbol(c) => write!(
2783                f,
2784                "invalid symbol \"{}\" found while decoding base64 sequence",
2785                c.escape_default()
2786            ),
2787            EvalError::InvalidBase64EndSequence => f.write_str("invalid base64 end sequence"),
2788            EvalError::InvalidJsonbCast { from, to } => {
2789                write!(f, "cannot cast jsonb {} to type {}", from, to)
2790            }
2791            EvalError::InvalidTimezone(tz) => write!(f, "invalid time zone '{}'", tz),
2792            EvalError::InvalidTimezoneInterval => {
2793                f.write_str("timezone interval must not contain months or years")
2794            }
2795            EvalError::InvalidTimezoneConversion => f.write_str("invalid timezone conversion"),
2796            EvalError::InvalidIanaTimezoneId(tz) => {
2797                write!(f, "invalid IANA Time Zone Database identifier: '{}'", tz)
2798            }
2799            EvalError::InvalidLayer { max_layer, val } => write!(
2800                f,
2801                "invalid layer: {}; must use value within [1, {}]",
2802                val, max_layer
2803            ),
2804            EvalError::InvalidArray(e) => e.fmt(f),
2805            EvalError::InvalidEncodingName(name) => write!(f, "invalid encoding name '{}'", name),
2806            EvalError::InvalidHashAlgorithm(alg) => write!(f, "invalid hash algorithm '{}'", alg),
2807            EvalError::InvalidByteSequence {
2808                byte_sequence,
2809                encoding_name,
2810            } => write!(
2811                f,
2812                "invalid byte sequence '{}' for encoding '{}'",
2813                byte_sequence, encoding_name
2814            ),
2815            EvalError::InvalidDatePart(part) => write!(f, "invalid datepart {}", part.quoted()),
2816            EvalError::KeyCannotBeNull => f.write_str("key cannot be null"),
2817            EvalError::NegSqrt => f.write_str("cannot take square root of a negative number"),
2818            EvalError::NegLimit => f.write_str("LIMIT must not be negative"),
2819            EvalError::NullCharacterNotPermitted => f.write_str("null character not permitted"),
2820            EvalError::InvalidRegex(e) => write!(f, "invalid regular expression: {}", e),
2821            EvalError::InvalidRegexFlag(c) => write!(f, "invalid regular expression flag: {}", c),
2822            EvalError::InvalidParameterValue(s) => f.write_str(s),
2823            EvalError::UnknownUnits(units) => write!(f, "unit '{}' not recognized", units),
2824            EvalError::UnsupportedUnits(units, typ) => {
2825                write!(f, "unit '{}' not supported for type {}", units, typ)
2826            }
2827            EvalError::UnterminatedLikeEscapeSequence => {
2828                f.write_str("unterminated escape sequence in LIKE")
2829            }
2830            EvalError::Parse(e) => e.fmt(f),
2831            EvalError::PrettyError(e) => e.fmt(f),
2832            EvalError::ParseHex(e) => e.fmt(f),
2833            EvalError::Internal(s) => write!(f, "internal error: {}", s),
2834            EvalError::InfinityOutOfDomain(s) => {
2835                write!(f, "function {} is only defined for finite arguments", s)
2836            }
2837            EvalError::NegativeOutOfDomain(s) => {
2838                write!(f, "function {} is not defined for negative numbers", s)
2839            }
2840            EvalError::ZeroOutOfDomain(s) => {
2841                write!(f, "function {} is not defined for zero", s)
2842            }
2843            EvalError::OutOfDomain(lower, upper, s) => {
2844                use DomainLimit::*;
2845                write!(f, "function {s} is defined for numbers ")?;
2846                match (lower, upper) {
2847                    (Inclusive(n), None) => write!(f, "greater than or equal to {n}"),
2848                    (Exclusive(n), None) => write!(f, "greater than {n}"),
2849                    (None, Inclusive(n)) => write!(f, "less than or equal to {n}"),
2850                    (None, Exclusive(n)) => write!(f, "less than {n}"),
2851                    (Inclusive(lo), Inclusive(hi)) => write!(f, "between {lo} and {hi} inclusive"),
2852                    (Exclusive(lo), Exclusive(hi)) => write!(f, "between {lo} and {hi} exclusive"),
2853                    (Inclusive(lo), Exclusive(hi)) => {
2854                        write!(f, "between {lo} inclusive and {hi} exclusive")
2855                    }
2856                    (Exclusive(lo), Inclusive(hi)) => {
2857                        write!(f, "between {lo} exclusive and {hi} inclusive")
2858                    }
2859                    (None, None) => panic!("invalid domain error"),
2860                }
2861            }
2862            EvalError::ComplexOutOfRange(s) => {
2863                write!(f, "function {} cannot return complex numbers", s)
2864            }
2865            EvalError::MultipleRowsFromSubquery => {
2866                write!(f, "more than one record produced in subquery")
2867            }
2868            EvalError::Undefined(s) => {
2869                write!(f, "{} is undefined", s)
2870            }
2871            EvalError::LikePatternTooLong => {
2872                write!(f, "LIKE pattern exceeds maximum length")
2873            }
2874            EvalError::LikeEscapeTooLong => {
2875                write!(f, "invalid escape string")
2876            }
2877            EvalError::StringValueTooLong {
2878                target_type,
2879                length,
2880            } => {
2881                write!(f, "value too long for type {}({})", target_type, length)
2882            }
2883            EvalError::MultidimensionalArrayRemovalNotSupported => {
2884                write!(
2885                    f,
2886                    "removing elements from multidimensional arrays is not supported"
2887                )
2888            }
2889            EvalError::IncompatibleArrayDimensions { dims: _ } => {
2890                write!(f, "cannot concatenate incompatible arrays")
2891            }
2892            EvalError::TypeFromOid(msg) => write!(f, "{msg}"),
2893            EvalError::InvalidRange(e) => e.fmt(f),
2894            EvalError::InvalidRoleId(msg) => write!(f, "{msg}"),
2895            EvalError::InvalidPrivileges(privilege) => {
2896                write!(f, "unrecognized privilege type: {privilege}")
2897            }
2898            EvalError::LetRecLimitExceeded(max_iters) => {
2899                write!(
2900                    f,
2901                    "Recursive query exceeded the recursion limit {}. (Use RETURN AT RECURSION LIMIT to not error, but return the current state as the final result when reaching the limit.)",
2902                    max_iters
2903                )
2904            }
2905            EvalError::MultiDimensionalArraySearch => write!(
2906                f,
2907                "searching for elements in multidimensional arrays is not supported"
2908            ),
2909            EvalError::MustNotBeNull(v) => write!(f, "{v} must not be null"),
2910            EvalError::InvalidIdentifier { ident, .. } => {
2911                write!(f, "string is not a valid identifier: {}", ident.quoted())
2912            }
2913            EvalError::ArrayFillWrongArraySubscripts => {
2914                f.write_str("wrong number of array subscripts")
2915            }
2916            EvalError::MaxArraySizeExceeded(max_size) => {
2917                write!(
2918                    f,
2919                    "array size exceeds the maximum allowed ({max_size} bytes)"
2920                )
2921            }
2922            EvalError::DateDiffOverflow { unit, a, b } => {
2923                write!(f, "datediff overflow, {unit} of {a}, {b}")
2924            }
2925            EvalError::IfNullError(s) => f.write_str(s),
2926            EvalError::LengthTooLarge => write!(f, "requested length too large"),
2927            EvalError::AclArrayNullElement => write!(f, "ACL arrays must not contain null values"),
2928            EvalError::MzAclArrayNullElement => {
2929                write!(f, "MZ_ACL arrays must not contain null values")
2930            }
2931        }
2932    }
2933}
2934
2935impl EvalError {
2936    pub fn detail(&self) -> Option<String> {
2937        match self {
2938            EvalError::IncompatibleArrayDimensions { dims: None } => Some(
2939                "Arrays with differing dimensions are not compatible for concatenation.".into(),
2940            ),
2941            EvalError::IncompatibleArrayDimensions {
2942                dims: Some((a_dims, b_dims)),
2943            } => Some(format!(
2944                "Arrays of {} and {} dimensions are not compatible for concatenation.",
2945                a_dims, b_dims
2946            )),
2947            EvalError::InvalidIdentifier { detail, .. } => detail.as_deref().map(Into::into),
2948            EvalError::ArrayFillWrongArraySubscripts => {
2949                Some("Low bound array has different size than dimensions array.".into())
2950            }
2951            _ => None,
2952        }
2953    }
2954
2955    pub fn hint(&self) -> Option<String> {
2956        match self {
2957            EvalError::InvalidBase64EndSequence => Some(
2958                "Input data is missing padding, is truncated, or is otherwise corrupted.".into(),
2959            ),
2960            EvalError::LikeEscapeTooLong => {
2961                Some("Escape string must be empty or one character.".into())
2962            }
2963            EvalError::MzTimestampOutOfRange(_) => Some(
2964                "Integer, numeric, and text casts to mz_timestamp must be in the form of whole \
2965                milliseconds since the Unix epoch. Values with fractional parts cannot be \
2966                converted to mz_timestamp."
2967                    .into(),
2968            ),
2969            _ => None,
2970        }
2971    }
2972}
2973
// Marks `EvalError` as a standard error type; no trait methods are overridden,
// so only the trait's provided defaults apply.
impl std::error::Error for EvalError {}
2975
2976impl From<ParseError> for EvalError {
2977    fn from(e: ParseError) -> EvalError {
2978        EvalError::Parse(e)
2979    }
2980}
2981
2982impl From<ParseHexError> for EvalError {
2983    fn from(e: ParseHexError) -> EvalError {
2984        EvalError::ParseHex(e)
2985    }
2986}
2987
2988impl From<InvalidArrayError> for EvalError {
2989    fn from(e: InvalidArrayError) -> EvalError {
2990        EvalError::InvalidArray(e)
2991    }
2992}
2993
2994impl From<RegexCompilationError> for EvalError {
2995    fn from(e: RegexCompilationError) -> EvalError {
2996        EvalError::InvalidRegex(e.to_string().into())
2997    }
2998}
2999
3000impl From<TypeFromOidError> for EvalError {
3001    fn from(e: TypeFromOidError) -> EvalError {
3002        EvalError::TypeFromOid(e.to_string().into())
3003    }
3004}
3005
3006impl From<DateError> for EvalError {
3007    fn from(e: DateError) -> EvalError {
3008        match e {
3009            DateError::OutOfRange => EvalError::DateOutOfRange,
3010        }
3011    }
3012}
3013
3014impl From<TimestampError> for EvalError {
3015    fn from(e: TimestampError) -> EvalError {
3016        match e {
3017            TimestampError::OutOfRange => EvalError::TimestampOutOfRange,
3018        }
3019    }
3020}
3021
3022impl From<InvalidRangeError> for EvalError {
3023    fn from(e: InvalidRangeError) -> EvalError {
3024        EvalError::InvalidRange(e)
3025    }
3026}
3027
3028impl RustType<ProtoEvalError> for EvalError {
3029    fn into_proto(&self) -> ProtoEvalError {
3030        use proto_eval_error::Kind::*;
3031        use proto_eval_error::*;
3032        let kind = match self {
3033            EvalError::CharacterNotValidForEncoding(v) => CharacterNotValidForEncoding(*v),
3034            EvalError::CharacterTooLargeForEncoding(v) => CharacterTooLargeForEncoding(*v),
3035            EvalError::DateBinOutOfRange(v) => DateBinOutOfRange(v.into_proto()),
3036            EvalError::DivisionByZero => DivisionByZero(()),
3037            EvalError::Unsupported {
3038                feature,
3039                discussion_no,
3040            } => Unsupported(ProtoUnsupported {
3041                feature: feature.into_proto(),
3042                discussion_no: discussion_no.into_proto(),
3043            }),
3044            EvalError::FloatOverflow => FloatOverflow(()),
3045            EvalError::FloatUnderflow => FloatUnderflow(()),
3046            EvalError::NumericFieldOverflow => NumericFieldOverflow(()),
3047            EvalError::Float32OutOfRange(val) => Float32OutOfRange(ProtoValueOutOfRange {
3048                value: val.to_string(),
3049            }),
3050            EvalError::Float64OutOfRange(val) => Float64OutOfRange(ProtoValueOutOfRange {
3051                value: val.to_string(),
3052            }),
3053            EvalError::Int16OutOfRange(val) => Int16OutOfRange(ProtoValueOutOfRange {
3054                value: val.to_string(),
3055            }),
3056            EvalError::Int32OutOfRange(val) => Int32OutOfRange(ProtoValueOutOfRange {
3057                value: val.to_string(),
3058            }),
3059            EvalError::Int64OutOfRange(val) => Int64OutOfRange(ProtoValueOutOfRange {
3060                value: val.to_string(),
3061            }),
3062            EvalError::UInt16OutOfRange(val) => Uint16OutOfRange(ProtoValueOutOfRange {
3063                value: val.to_string(),
3064            }),
3065            EvalError::UInt32OutOfRange(val) => Uint32OutOfRange(ProtoValueOutOfRange {
3066                value: val.to_string(),
3067            }),
3068            EvalError::UInt64OutOfRange(val) => Uint64OutOfRange(ProtoValueOutOfRange {
3069                value: val.to_string(),
3070            }),
3071            EvalError::MzTimestampOutOfRange(val) => MzTimestampOutOfRange(ProtoValueOutOfRange {
3072                value: val.to_string(),
3073            }),
3074            EvalError::MzTimestampStepOverflow => MzTimestampStepOverflow(()),
3075            EvalError::OidOutOfRange(val) => OidOutOfRange(ProtoValueOutOfRange {
3076                value: val.to_string(),
3077            }),
3078            EvalError::IntervalOutOfRange(val) => IntervalOutOfRange(ProtoValueOutOfRange {
3079                value: val.to_string(),
3080            }),
3081            EvalError::TimestampCannotBeNan => TimestampCannotBeNan(()),
3082            EvalError::TimestampOutOfRange => TimestampOutOfRange(()),
3083            EvalError::DateOutOfRange => DateOutOfRange(()),
3084            EvalError::CharOutOfRange => CharOutOfRange(()),
3085            EvalError::IndexOutOfRange {
3086                provided,
3087                valid_end,
3088            } => IndexOutOfRange(ProtoIndexOutOfRange {
3089                provided: *provided,
3090                valid_end: *valid_end,
3091            }),
3092            EvalError::InvalidBase64Equals => InvalidBase64Equals(()),
3093            EvalError::InvalidBase64Symbol(sym) => InvalidBase64Symbol(sym.into_proto()),
3094            EvalError::InvalidBase64EndSequence => InvalidBase64EndSequence(()),
3095            EvalError::InvalidTimezone(tz) => InvalidTimezone(tz.into_proto()),
3096            EvalError::InvalidTimezoneInterval => InvalidTimezoneInterval(()),
3097            EvalError::InvalidTimezoneConversion => InvalidTimezoneConversion(()),
3098            EvalError::InvalidLayer { max_layer, val } => InvalidLayer(ProtoInvalidLayer {
3099                max_layer: max_layer.into_proto(),
3100                val: *val,
3101            }),
3102            EvalError::InvalidArray(error) => InvalidArray(error.into_proto()),
3103            EvalError::InvalidEncodingName(v) => InvalidEncodingName(v.into_proto()),
3104            EvalError::InvalidHashAlgorithm(v) => InvalidHashAlgorithm(v.into_proto()),
3105            EvalError::InvalidByteSequence {
3106                byte_sequence,
3107                encoding_name,
3108            } => InvalidByteSequence(ProtoInvalidByteSequence {
3109                byte_sequence: byte_sequence.into_proto(),
3110                encoding_name: encoding_name.into_proto(),
3111            }),
3112            EvalError::InvalidJsonbCast { from, to } => InvalidJsonbCast(ProtoInvalidJsonbCast {
3113                from: from.into_proto(),
3114                to: to.into_proto(),
3115            }),
3116            EvalError::InvalidRegex(v) => InvalidRegex(v.into_proto()),
3117            EvalError::InvalidRegexFlag(v) => InvalidRegexFlag(v.into_proto()),
3118            EvalError::InvalidParameterValue(v) => InvalidParameterValue(v.into_proto()),
3119            EvalError::InvalidDatePart(part) => InvalidDatePart(part.into_proto()),
3120            EvalError::KeyCannotBeNull => KeyCannotBeNull(()),
3121            EvalError::NegSqrt => NegSqrt(()),
3122            EvalError::NegLimit => NegLimit(()),
3123            EvalError::NullCharacterNotPermitted => NullCharacterNotPermitted(()),
3124            EvalError::UnknownUnits(v) => UnknownUnits(v.into_proto()),
3125            EvalError::UnsupportedUnits(units, typ) => UnsupportedUnits(ProtoUnsupportedUnits {
3126                units: units.into_proto(),
3127                typ: typ.into_proto(),
3128            }),
3129            EvalError::UnterminatedLikeEscapeSequence => UnterminatedLikeEscapeSequence(()),
3130            EvalError::Parse(error) => Parse(error.into_proto()),
3131            EvalError::PrettyError(error) => PrettyError(error.into_proto()),
3132            EvalError::ParseHex(error) => ParseHex(error.into_proto()),
3133            EvalError::Internal(v) => Internal(v.into_proto()),
3134            EvalError::InfinityOutOfDomain(v) => InfinityOutOfDomain(v.into_proto()),
3135            EvalError::NegativeOutOfDomain(v) => NegativeOutOfDomain(v.into_proto()),
3136            EvalError::ZeroOutOfDomain(v) => ZeroOutOfDomain(v.into_proto()),
3137            EvalError::OutOfDomain(lower, upper, id) => OutOfDomain(ProtoOutOfDomain {
3138                lower: Some(lower.into_proto()),
3139                upper: Some(upper.into_proto()),
3140                id: id.into_proto(),
3141            }),
3142            EvalError::ComplexOutOfRange(v) => ComplexOutOfRange(v.into_proto()),
3143            EvalError::MultipleRowsFromSubquery => MultipleRowsFromSubquery(()),
3144            EvalError::Undefined(v) => Undefined(v.into_proto()),
3145            EvalError::LikePatternTooLong => LikePatternTooLong(()),
3146            EvalError::LikeEscapeTooLong => LikeEscapeTooLong(()),
3147            EvalError::StringValueTooLong {
3148                target_type,
3149                length,
3150            } => StringValueTooLong(ProtoStringValueTooLong {
3151                target_type: target_type.into_proto(),
3152                length: length.into_proto(),
3153            }),
3154            EvalError::MultidimensionalArrayRemovalNotSupported => {
3155                MultidimensionalArrayRemovalNotSupported(())
3156            }
3157            EvalError::IncompatibleArrayDimensions { dims } => {
3158                IncompatibleArrayDimensions(ProtoIncompatibleArrayDimensions {
3159                    dims: dims.into_proto(),
3160                })
3161            }
3162            EvalError::TypeFromOid(v) => TypeFromOid(v.into_proto()),
3163            EvalError::InvalidRange(error) => InvalidRange(error.into_proto()),
3164            EvalError::InvalidRoleId(v) => InvalidRoleId(v.into_proto()),
3165            EvalError::InvalidPrivileges(v) => InvalidPrivileges(v.into_proto()),
3166            EvalError::LetRecLimitExceeded(v) => WmrRecursionLimitExceeded(v.into_proto()),
3167            EvalError::MultiDimensionalArraySearch => MultiDimensionalArraySearch(()),
3168            EvalError::MustNotBeNull(v) => MustNotBeNull(v.into_proto()),
3169            EvalError::InvalidIdentifier { ident, detail } => {
3170                InvalidIdentifier(ProtoInvalidIdentifier {
3171                    ident: ident.into_proto(),
3172                    detail: detail.into_proto(),
3173                })
3174            }
3175            EvalError::ArrayFillWrongArraySubscripts => ArrayFillWrongArraySubscripts(()),
3176            EvalError::MaxArraySizeExceeded(max_size) => {
3177                MaxArraySizeExceeded(u64::cast_from(*max_size))
3178            }
3179            EvalError::DateDiffOverflow { unit, a, b } => DateDiffOverflow(ProtoDateDiffOverflow {
3180                unit: unit.into_proto(),
3181                a: a.into_proto(),
3182                b: b.into_proto(),
3183            }),
3184            EvalError::IfNullError(s) => IfNullError(s.into_proto()),
3185            EvalError::LengthTooLarge => LengthTooLarge(()),
3186            EvalError::AclArrayNullElement => AclArrayNullElement(()),
3187            EvalError::MzAclArrayNullElement => MzAclArrayNullElement(()),
3188            EvalError::InvalidIanaTimezoneId(s) => InvalidIanaTimezoneId(s.into_proto()),
3189        };
3190        ProtoEvalError { kind: Some(kind) }
3191    }
3192
3193    fn from_proto(proto: ProtoEvalError) -> Result<Self, TryFromProtoError> {
3194        use proto_eval_error::Kind::*;
3195        match proto.kind {
3196            Some(kind) => match kind {
3197                CharacterNotValidForEncoding(v) => Ok(EvalError::CharacterNotValidForEncoding(v)),
3198                CharacterTooLargeForEncoding(v) => Ok(EvalError::CharacterTooLargeForEncoding(v)),
3199                DateBinOutOfRange(v) => Ok(EvalError::DateBinOutOfRange(v.into())),
3200                DivisionByZero(()) => Ok(EvalError::DivisionByZero),
3201                Unsupported(v) => Ok(EvalError::Unsupported {
3202                    feature: v.feature.into(),
3203                    discussion_no: v.discussion_no.into_rust()?,
3204                }),
3205                FloatOverflow(()) => Ok(EvalError::FloatOverflow),
3206                FloatUnderflow(()) => Ok(EvalError::FloatUnderflow),
3207                NumericFieldOverflow(()) => Ok(EvalError::NumericFieldOverflow),
3208                Float32OutOfRange(val) => Ok(EvalError::Float32OutOfRange(val.value.into())),
3209                Float64OutOfRange(val) => Ok(EvalError::Float64OutOfRange(val.value.into())),
3210                Int16OutOfRange(val) => Ok(EvalError::Int16OutOfRange(val.value.into())),
3211                Int32OutOfRange(val) => Ok(EvalError::Int32OutOfRange(val.value.into())),
3212                Int64OutOfRange(val) => Ok(EvalError::Int64OutOfRange(val.value.into())),
3213                Uint16OutOfRange(val) => Ok(EvalError::UInt16OutOfRange(val.value.into())),
3214                Uint32OutOfRange(val) => Ok(EvalError::UInt32OutOfRange(val.value.into())),
3215                Uint64OutOfRange(val) => Ok(EvalError::UInt64OutOfRange(val.value.into())),
3216                MzTimestampOutOfRange(val) => {
3217                    Ok(EvalError::MzTimestampOutOfRange(val.value.into()))
3218                }
3219                MzTimestampStepOverflow(()) => Ok(EvalError::MzTimestampStepOverflow),
3220                OidOutOfRange(val) => Ok(EvalError::OidOutOfRange(val.value.into())),
3221                IntervalOutOfRange(val) => Ok(EvalError::IntervalOutOfRange(val.value.into())),
3222                TimestampCannotBeNan(()) => Ok(EvalError::TimestampCannotBeNan),
3223                TimestampOutOfRange(()) => Ok(EvalError::TimestampOutOfRange),
3224                DateOutOfRange(()) => Ok(EvalError::DateOutOfRange),
3225                CharOutOfRange(()) => Ok(EvalError::CharOutOfRange),
3226                IndexOutOfRange(v) => Ok(EvalError::IndexOutOfRange {
3227                    provided: v.provided,
3228                    valid_end: v.valid_end,
3229                }),
3230                InvalidBase64Equals(()) => Ok(EvalError::InvalidBase64Equals),
3231                InvalidBase64Symbol(v) => char::from_proto(v).map(EvalError::InvalidBase64Symbol),
3232                InvalidBase64EndSequence(()) => Ok(EvalError::InvalidBase64EndSequence),
3233                InvalidTimezone(v) => Ok(EvalError::InvalidTimezone(v.into())),
3234                InvalidTimezoneInterval(()) => Ok(EvalError::InvalidTimezoneInterval),
3235                InvalidTimezoneConversion(()) => Ok(EvalError::InvalidTimezoneConversion),
3236                InvalidLayer(v) => Ok(EvalError::InvalidLayer {
3237                    max_layer: usize::from_proto(v.max_layer)?,
3238                    val: v.val,
3239                }),
3240                InvalidArray(error) => Ok(EvalError::InvalidArray(error.into_rust()?)),
3241                InvalidEncodingName(v) => Ok(EvalError::InvalidEncodingName(v.into())),
3242                InvalidHashAlgorithm(v) => Ok(EvalError::InvalidHashAlgorithm(v.into())),
3243                InvalidByteSequence(v) => Ok(EvalError::InvalidByteSequence {
3244                    byte_sequence: v.byte_sequence.into(),
3245                    encoding_name: v.encoding_name.into(),
3246                }),
3247                InvalidJsonbCast(v) => Ok(EvalError::InvalidJsonbCast {
3248                    from: v.from.into(),
3249                    to: v.to.into(),
3250                }),
3251                InvalidRegex(v) => Ok(EvalError::InvalidRegex(v.into())),
3252                InvalidRegexFlag(v) => Ok(EvalError::InvalidRegexFlag(char::from_proto(v)?)),
3253                InvalidParameterValue(v) => Ok(EvalError::InvalidParameterValue(v.into())),
3254                InvalidDatePart(part) => Ok(EvalError::InvalidDatePart(part.into())),
3255                KeyCannotBeNull(()) => Ok(EvalError::KeyCannotBeNull),
3256                NegSqrt(()) => Ok(EvalError::NegSqrt),
3257                NegLimit(()) => Ok(EvalError::NegLimit),
3258                NullCharacterNotPermitted(()) => Ok(EvalError::NullCharacterNotPermitted),
3259                UnknownUnits(v) => Ok(EvalError::UnknownUnits(v.into())),
3260                UnsupportedUnits(v) => {
3261                    Ok(EvalError::UnsupportedUnits(v.units.into(), v.typ.into()))
3262                }
3263                UnterminatedLikeEscapeSequence(()) => Ok(EvalError::UnterminatedLikeEscapeSequence),
3264                Parse(error) => Ok(EvalError::Parse(error.into_rust()?)),
3265                ParseHex(error) => Ok(EvalError::ParseHex(error.into_rust()?)),
3266                Internal(v) => Ok(EvalError::Internal(v.into())),
3267                InfinityOutOfDomain(v) => Ok(EvalError::InfinityOutOfDomain(v.into())),
3268                NegativeOutOfDomain(v) => Ok(EvalError::NegativeOutOfDomain(v.into())),
3269                ZeroOutOfDomain(v) => Ok(EvalError::ZeroOutOfDomain(v.into())),
3270                OutOfDomain(v) => Ok(EvalError::OutOfDomain(
3271                    v.lower.into_rust_if_some("ProtoDomainLimit::lower")?,
3272                    v.upper.into_rust_if_some("ProtoDomainLimit::upper")?,
3273                    v.id.into(),
3274                )),
3275                ComplexOutOfRange(v) => Ok(EvalError::ComplexOutOfRange(v.into())),
3276                MultipleRowsFromSubquery(()) => Ok(EvalError::MultipleRowsFromSubquery),
3277                Undefined(v) => Ok(EvalError::Undefined(v.into())),
3278                LikePatternTooLong(()) => Ok(EvalError::LikePatternTooLong),
3279                LikeEscapeTooLong(()) => Ok(EvalError::LikeEscapeTooLong),
3280                StringValueTooLong(v) => Ok(EvalError::StringValueTooLong {
3281                    target_type: v.target_type.into(),
3282                    length: usize::from_proto(v.length)?,
3283                }),
3284                MultidimensionalArrayRemovalNotSupported(()) => {
3285                    Ok(EvalError::MultidimensionalArrayRemovalNotSupported)
3286                }
3287                IncompatibleArrayDimensions(v) => Ok(EvalError::IncompatibleArrayDimensions {
3288                    dims: v.dims.into_rust()?,
3289                }),
3290                TypeFromOid(v) => Ok(EvalError::TypeFromOid(v.into())),
3291                InvalidRange(e) => Ok(EvalError::InvalidRange(e.into_rust()?)),
3292                InvalidRoleId(v) => Ok(EvalError::InvalidRoleId(v.into())),
3293                InvalidPrivileges(v) => Ok(EvalError::InvalidPrivileges(v.into())),
3294                WmrRecursionLimitExceeded(v) => Ok(EvalError::LetRecLimitExceeded(v.into())),
3295                MultiDimensionalArraySearch(()) => Ok(EvalError::MultiDimensionalArraySearch),
3296                MustNotBeNull(v) => Ok(EvalError::MustNotBeNull(v.into())),
3297                InvalidIdentifier(v) => Ok(EvalError::InvalidIdentifier {
3298                    ident: v.ident.into(),
3299                    detail: v.detail.into_rust()?,
3300                }),
3301                ArrayFillWrongArraySubscripts(()) => Ok(EvalError::ArrayFillWrongArraySubscripts),
3302                MaxArraySizeExceeded(max_size) => {
3303                    Ok(EvalError::MaxArraySizeExceeded(usize::cast_from(max_size)))
3304                }
3305                DateDiffOverflow(v) => Ok(EvalError::DateDiffOverflow {
3306                    unit: v.unit.into(),
3307                    a: v.a.into(),
3308                    b: v.b.into(),
3309                }),
3310                IfNullError(v) => Ok(EvalError::IfNullError(v.into())),
3311                LengthTooLarge(()) => Ok(EvalError::LengthTooLarge),
3312                AclArrayNullElement(()) => Ok(EvalError::AclArrayNullElement),
3313                MzAclArrayNullElement(()) => Ok(EvalError::MzAclArrayNullElement),
3314                InvalidIanaTimezoneId(s) => Ok(EvalError::InvalidIanaTimezoneId(s.into())),
3315                PrettyError(s) => Ok(EvalError::PrettyError(s.into())),
3316            },
3317            None => Err(TryFromProtoError::missing_field("ProtoEvalError::kind")),
3318        }
3319    }
3320}
3321
3322impl RustType<ProtoDims> for (usize, usize) {
3323    fn into_proto(&self) -> ProtoDims {
3324        ProtoDims {
3325            f0: self.0.into_proto(),
3326            f1: self.1.into_proto(),
3327        }
3328    }
3329
3330    fn from_proto(proto: ProtoDims) -> Result<Self, TryFromProtoError> {
3331        Ok((proto.f0.into_rust()?, proto.f1.into_rust()?))
3332    }
3333}
3334
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `MirScalarExpr::reduce` on a series of `Coalesce` calls and
    /// compares each reduced expression against a hand-written expectation.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_reduce() {
        // Three Int64 columns: #0 and #1 are nullable, #2 is not.
        let relation_type = [true, true, false]
            .iter()
            .map(|&nullable| SqlScalarType::Int64.nullable(nullable))
            .collect::<Vec<_>>();

        // Small builders that keep the case table below readable.
        let col = MirScalarExpr::column;
        let err = |e| MirScalarExpr::literal(Err(e), SqlScalarType::Int64);
        let lit = |i| MirScalarExpr::literal_ok(Datum::Int64(i), SqlScalarType::Int64);
        let null = || MirScalarExpr::literal_null(SqlScalarType::Int64);
        let coalesce = |exprs: Vec<MirScalarExpr>| MirScalarExpr::CallVariadic {
            func: VariadicFunc::Coalesce,
            exprs,
        };

        // Each entry is `(input, expected result of reducing the input)`.
        let cases = vec![
            // A lone literal argument is expected to replace the whole call.
            (coalesce(vec![lit(1)]), lit(1)),
            // A leading non-null literal is expected to win over later arguments.
            (coalesce(vec![lit(1), lit(2)]), lit(1)),
            // Null literals around a non-null literal are expected to vanish.
            (coalesce(vec![null(), lit(2), null()]), lit(2)),
            // Null literals are dropped and the argument list is cut off after
            // the first non-null literal.
            (
                coalesce(vec![null(), col(0), null(), col(1), lit(2), lit(3)]),
                coalesce(vec![col(0), col(1), lit(2)]),
            ),
            // Column #2 is declared non-nullable; the argument after it is
            // expected to be dropped.
            (
                coalesce(vec![col(0), col(2), col(1)]),
                coalesce(vec![col(0), col(2)]),
            ),
            // An error literal behind a non-null literal is expected to be discarded.
            (
                coalesce(vec![lit(1), err(EvalError::DivisionByZero)]),
                lit(1),
            ),
            // With only a null ahead of it, the first error literal is the
            // expected result.
            (
                coalesce(vec![
                    null(),
                    err(EvalError::DivisionByZero),
                    err(EvalError::NumericFieldOverflow),
                ]),
                err(EvalError::DivisionByZero),
            ),
        ];

        for (input, expected) in cases {
            // Reduce a copy so the untouched input can be shown on failure.
            let mut actual = input.clone();
            actual.reduce(&relation_type);
            assert!(
                actual == expected,
                "input: {}\nactual: {}\nexpected: {}",
                input,
                actual,
                expected
            );
        }
    }
}