Skip to main content

mz_expr/
scalar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::ops::BitOrAssign;
12use std::sync::Arc;
13use std::{fmt, mem};
14
15use itertools::Itertools;
16use mz_lowertest::MzReflect;
17use mz_ore::cast::CastFrom;
18use mz_ore::collections::CollectionExt;
19use mz_ore::iter::IteratorExt;
20use mz_ore::soft_assert_or_log;
21use mz_ore::stack::RecursionLimitError;
22use mz_ore::str::StrExt;
23use mz_ore::treat_as_equal::TreatAsEqual;
24use mz_ore::vec::swap_remove_multiple;
25use mz_pgrepr::TypeFromOidError;
26use mz_pgtz::timezone::TimezoneSpec;
27use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
28use mz_repr::adt::array::InvalidArrayError;
29use mz_repr::adt::date::DateError;
30use mz_repr::adt::datetime::DateTimeUnits;
31use mz_repr::adt::range::InvalidRangeError;
32use mz_repr::adt::regex::{Regex, RegexCompilationError};
33use mz_repr::adt::timestamp::TimestampError;
34use mz_repr::strconv::{ParseError, ParseHexError};
35use mz_repr::{Datum, ReprColumnType, ReprScalarType, Row, RowArena, SqlColumnType, SqlScalarType};
36
37use proptest::prelude::*;
38use proptest_derive::Arbitrary;
39use serde::{Deserialize, Serialize};
40
41use crate::scalar::func::format::DateTimeFormat;
42use crate::scalar::func::variadic::{
43    And, Coalesce, ListCreate, ListIndex, Or, RegexpMatch, RegexpReplace, RegexpSplitToArray,
44};
45use crate::scalar::func::{
46    BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, parse_timezone,
47    regexp_replace_parse_flags,
48};
49use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims;
50use crate::visit::{Visit, VisitChildren};
51
52pub mod func;
53pub mod like_pattern;
54
55include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs"));
56
/// A scalar expression in MIR: a tree of column references, literals and function calls.
///
/// Note: `PartialOrd`/`Ord` are derived, so the declaration order of the variants below
/// participates in how expressions are ordered; reordering variants changes that ordering.
#[derive(
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    MzReflect
)]
pub enum MirScalarExpr {
    /// A column of the input row, identified by its index, plus an optional name
    /// (wrapped in `TreatAsEqual`; see `MirScalarExpr::named_column`).
    Column(usize, TreatAsEqual<Option<Arc<str>>>),
    /// A literal value.
    /// (Stored as a row, because we can't own a Datum)
    ///
    /// The `Result` holds either the packed single-datum row or the evaluation error the
    /// literal represents; the `ReprColumnType` records its scalar type and nullability.
    Literal(Result<Row, EvalError>, ReprColumnType),
    /// A call to an unmaterializable function.
    ///
    /// These functions cannot be evaluated by `MirScalarExpr::eval`. They must
    /// be transformed away by a higher layer.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function call that takes one expression as an argument.
    CallUnary {
        func: UnaryFunc,
        expr: Box<MirScalarExpr>,
    },
    /// A function call that takes two expressions as arguments.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<MirScalarExpr>,
        expr2: Box<MirScalarExpr>,
    },
    /// A function call that takes an arbitrary number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<MirScalarExpr>,
    },
    /// Conditionally evaluated expressions.
    ///
    /// It is important that `then` and `els` only be evaluated if
    /// `cond` is true or not, respectively. This is the only way
    /// users can guard execution (other logical operators do not
    /// short-circuit) and we need to preserve that.
    If {
        cond: Box<MirScalarExpr>,
        then: Box<MirScalarExpr>,
        els: Box<MirScalarExpr>,
    },
}
107
108// We need a custom Debug because we don't want to show `None` for name information.
109// Sadly, the `derivative` crate doesn't support this use case.
110impl std::fmt::Debug for MirScalarExpr {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self {
113            MirScalarExpr::Column(i, TreatAsEqual(Some(name))) => {
114                write!(f, "Column({i}, {name:?})")
115            }
116            MirScalarExpr::Column(i, TreatAsEqual(None)) => write!(f, "Column({i})"),
117            MirScalarExpr::Literal(lit, typ) => write!(f, "Literal({lit:?}, {typ:?})"),
118            MirScalarExpr::CallUnmaterializable(func) => {
119                write!(f, "CallUnmaterializable({func:?})")
120            }
121            MirScalarExpr::CallUnary { func, expr } => {
122                write!(f, "CallUnary({func:?}, {expr:?})")
123            }
124            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
125                write!(f, "CallBinary({func:?}, {expr1:?}, {expr2:?})")
126            }
127            MirScalarExpr::CallVariadic { func, exprs } => {
128                write!(f, "CallVariadic({func:?}, {exprs:?})")
129            }
130            MirScalarExpr::If { cond, then, els } => {
131                write!(f, "If({cond:?}, {then:?}, {els:?})")
132            }
133        }
134    }
135}
136
137impl MirScalarExpr {
138    pub fn columns(is: &[usize]) -> Vec<MirScalarExpr> {
139        is.iter().map(|i| MirScalarExpr::column(*i)).collect()
140    }
141
142    pub fn column(column: usize) -> Self {
143        MirScalarExpr::Column(column, TreatAsEqual(None))
144    }
145
146    pub fn named_column(column: usize, name: Arc<str>) -> Self {
147        MirScalarExpr::Column(column, TreatAsEqual(Some(name)))
148    }
149
150    pub fn literal(res: Result<Datum, EvalError>, typ: ReprScalarType) -> Self {
151        let typ = ReprColumnType {
152            scalar_type: typ,
153            nullable: matches!(res, Ok(Datum::Null)),
154        };
155        let row = res.map(|datum| Row::pack_slice(&[datum]));
156        MirScalarExpr::Literal(row, typ)
157    }
158
159    pub fn literal_ok(datum: Datum, typ: ReprScalarType) -> Self {
160        MirScalarExpr::literal(Ok(datum), typ)
161    }
162
163    /// Constructs a `MirScalarExpr::Literal` from a pre-packed `Row`
164    /// containing a single datum and a `ReprScalarType`. Nullability is
165    /// derived by inspecting the first datum in the row.
166    pub fn literal_from_single_element_row(row: Row, typ: ReprScalarType) -> Self {
167        soft_assert_or_log!(
168            row.iter().count() == 1,
169            "literal_from_row called with a Row containing {} datums",
170            row.iter().count()
171        );
172        let nullable = row.unpack_first() == Datum::Null;
173        let typ = ReprColumnType {
174            scalar_type: typ,
175            nullable,
176        };
177        MirScalarExpr::Literal(Ok(row), typ)
178    }
179
180    pub fn literal_null(typ: ReprScalarType) -> Self {
181        MirScalarExpr::literal_ok(Datum::Null, typ)
182    }
183
184    pub fn literal_false() -> Self {
185        MirScalarExpr::literal_ok(Datum::False, ReprScalarType::Bool)
186    }
187
188    pub fn literal_true() -> Self {
189        MirScalarExpr::literal_ok(Datum::True, ReprScalarType::Bool)
190    }
191
192    pub fn call_unary<U: Into<UnaryFunc>>(self, func: U) -> Self {
193        MirScalarExpr::CallUnary {
194            func: func.into(),
195            expr: Box::new(self),
196        }
197    }
198
199    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
200        MirScalarExpr::CallBinary {
201            func: func.into(),
202            expr1: Box::new(self),
203            expr2: Box::new(other),
204        }
205    }
206
207    /// Call function `func` on `exprs`.
208    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
209        MirScalarExpr::CallVariadic {
210            func: func.into(),
211            exprs,
212        }
213    }
214
215    pub fn if_then_else(self, t: Self, f: Self) -> Self {
216        MirScalarExpr::If {
217            cond: Box::new(self),
218            then: Box::new(t),
219            els: Box::new(f),
220        }
221    }
222
223    pub fn or(self, other: Self) -> Self {
224        MirScalarExpr::call_variadic(Or, vec![self, other])
225    }
226
227    pub fn and(self, other: Self) -> Self {
228        MirScalarExpr::call_variadic(And, vec![self, other])
229    }
230
231    pub fn not(self) -> Self {
232        self.call_unary(UnaryFunc::Not(func::Not))
233    }
234
235    pub fn call_is_null(self) -> Self {
236        self.call_unary(UnaryFunc::IsNull(func::IsNull))
237    }
238
239    /// Match AND or OR on self and get the args. If no match, then interpret self as if it were
240    /// wrapped in a 1-arg AND/OR.
241    pub fn and_or_args(&self, func_to_match: VariadicFunc) -> Vec<MirScalarExpr> {
242        assert!(func_to_match == Or.into() || func_to_match == And.into());
243        match self {
244            MirScalarExpr::CallVariadic { func, exprs } if *func == func_to_match => exprs.clone(),
245            _ => vec![self.clone()],
246        }
247    }
248
249    /// Try to match a literal equality involving the given expression on one side.
250    /// Return the (non-null) literal and a bool that indicates whether an inversion was needed.
251    ///
252    /// More specifically:
253    /// If `self` is an equality with a `null` literal on any side, then the match fails!
254    /// Otherwise: for a given `expr`, if `self` is `<expr> = <literal>` or `<literal> = <expr>`
255    /// then return `Some((<literal>, false))`. In addition to just trying to match `<expr>` as it
256    /// is, we also try to remove an invertible function call (such as a cast). If the match
257    /// succeeds with the inversion, then return `Some((<inverted-literal>, true))`. For more
258    /// details on the inversion, see `invert_casts_on_expr_eq_literal_inner`.
259    pub fn expr_eq_literal(&self, expr: &MirScalarExpr) -> Option<(Row, bool)> {
260        if let MirScalarExpr::CallBinary {
261            func: BinaryFunc::Eq(_),
262            expr1,
263            expr2,
264        } = self
265        {
266            if expr1.is_literal_null() || expr2.is_literal_null() {
267                return None;
268            }
269            if let Some(Ok(lit)) = expr1.as_literal_owned() {
270                return Self::expr_eq_literal_inner(expr, lit, expr1, expr2);
271            }
272            if let Some(Ok(lit)) = expr2.as_literal_owned() {
273                return Self::expr_eq_literal_inner(expr, lit, expr2, expr1);
274            }
275        }
276        None
277    }
278
279    fn expr_eq_literal_inner(
280        expr_to_match: &MirScalarExpr,
281        literal: Row,
282        literal_expr: &MirScalarExpr,
283        other_side: &MirScalarExpr,
284    ) -> Option<(Row, bool)> {
285        if other_side == expr_to_match {
286            return Some((literal, false));
287        } else {
288            // expr didn't exactly match. See if we can match it by inverse-casting.
289            let (cast_removed, inv_cast_lit) =
290                Self::invert_casts_on_expr_eq_literal_inner(other_side, literal_expr);
291            if &cast_removed == expr_to_match {
292                if let Some(Ok(inv_cast_lit_row)) = inv_cast_lit.as_literal_owned() {
293                    return Some((inv_cast_lit_row, true));
294                }
295            }
296        }
297        None
298    }
299
300    /// If `self` is `<expr> = <literal>` or `<literal> = <expr>` then
301    /// return `<expr>`. It also tries to remove a cast (or other invertible function call) from
302    /// `<expr>` before returning it, see `invert_casts_on_expr_eq_literal_inner`.
303    pub fn any_expr_eq_literal(&self) -> Option<MirScalarExpr> {
304        if let MirScalarExpr::CallBinary {
305            func: BinaryFunc::Eq(_),
306            expr1,
307            expr2,
308        } = self
309        {
310            if expr1.is_literal() {
311                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
312                return Some(expr);
313            }
314            if expr2.is_literal() {
315                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
316                return Some(expr);
317            }
318        }
319        None
320    }
321
322    /// If the given `MirScalarExpr` is a literal equality where one side is an invertible function
323    /// call, then calls the inverse function on both sides of the equality and returns the modified
324    /// version of the given `MirScalarExpr`. Otherwise, it returns the original expression.
325    /// For more details, see `invert_casts_on_expr_eq_literal_inner`.
326    pub fn invert_casts_on_expr_eq_literal(&self) -> MirScalarExpr {
327        if let MirScalarExpr::CallBinary {
328            func: BinaryFunc::Eq(_),
329            expr1,
330            expr2,
331        } = self
332        {
333            if expr1.is_literal() {
334                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
335                return literal.call_binary(expr, func::Eq);
336            }
337            if expr2.is_literal() {
338                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
339                return literal.call_binary(expr, func::Eq);
340            }
341            // Note: The above return statements should be consistent in whether they put the
342            // literal in expr1 or expr2, for the deduplication in CanonicalizeMfp to work.
343        }
344        self.clone()
345    }
346
347    /// Given an `<expr>` and a `<literal>` that were taken out from `<expr> = <literal>` or
348    /// `<literal> = <expr>`, it tries to simplify the equality by applying the inverse function of
349    /// the outermost function call of `<expr>` (if exists):
350    ///
351    /// `<literal> = func(<inner_expr>)`, where `func` is invertible
352    ///  -->
353    /// `<func^-1(literal)> = <inner_expr>`
354    /// if `func^-1(literal)` doesn't error out, and both `func` and `func^-1` preserve uniqueness.
355    ///
356    /// The return value is the `<inner_expr>` and the literal value that we get by applying the
357    /// inverse function.
358    fn invert_casts_on_expr_eq_literal_inner(
359        expr: &MirScalarExpr,
360        literal: &MirScalarExpr,
361    ) -> (MirScalarExpr, MirScalarExpr) {
362        assert!(matches!(literal, MirScalarExpr::Literal(..)));
363
364        let temp_storage = &RowArena::new();
365        let eval = |e: &MirScalarExpr| {
366            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
367        };
368
369        if let MirScalarExpr::CallUnary {
370            func,
371            expr: inner_expr,
372        } = expr
373        {
374            if let Some(inverse_func) = func.inverse() {
375                // We don't want to remove a function call that doesn't preserve uniqueness, e.g.,
376                // if `f` is a float, we don't want to inverse-cast `f::INT = 0`, because the
377                // inserted int-to-float cast wouldn't be able to invert the rounding.
378                // Also, we don't want to insert a function call that doesn't preserve
379                // uniqueness. E.g., if `a` has an integer type, we don't want to do
380                // a surprise rounding for `WHERE a = 3.14`.
381                if func.preserves_uniqueness() && inverse_func.preserves_uniqueness() {
382                    let lit_inv = eval(&MirScalarExpr::CallUnary {
383                        func: inverse_func,
384                        expr: Box::new(literal.clone()),
385                    });
386                    // The evaluation can error out, e.g., when casting a too large int32 to int16.
387                    // This case is handled by `impossible_literal_equality_because_types`.
388                    if !lit_inv.is_literal_err() {
389                        return (*inner_expr.clone(), lit_inv);
390                    }
391                }
392            }
393        }
394        (expr.clone(), literal.clone())
395    }
396
397    /// Tries to remove a cast (or other invertible function) in the same way as
398    /// `invert_casts_on_expr_eq_literal`, but if calling the inverse function fails on the literal,
399    /// then it deems the equality to be impossible. For example if `a` is a smallint column, then
400    /// it catches `a::integer = 1000000` to be an always false predicate (where the `::integer`
401    /// could have been inserted implicitly).
402    pub fn impossible_literal_equality_because_types(&self) -> bool {
403        if let MirScalarExpr::CallBinary {
404            func: BinaryFunc::Eq(_),
405            expr1,
406            expr2,
407        } = self
408        {
409            if expr1.is_literal() {
410                return Self::impossible_literal_equality_because_types_inner(expr1, expr2);
411            }
412            if expr2.is_literal() {
413                return Self::impossible_literal_equality_because_types_inner(expr2, expr1);
414            }
415        }
416        false
417    }
418
419    fn impossible_literal_equality_because_types_inner(
420        literal: &MirScalarExpr,
421        other_side: &MirScalarExpr,
422    ) -> bool {
423        assert!(matches!(literal, MirScalarExpr::Literal(..)));
424
425        let temp_storage = &RowArena::new();
426        let eval = |e: &MirScalarExpr| {
427            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
428        };
429
430        if let MirScalarExpr::CallUnary { func, .. } = other_side {
431            if let Some(inverse_func) = func.inverse() {
432                if inverse_func.preserves_uniqueness()
433                    && eval(&MirScalarExpr::CallUnary {
434                        func: inverse_func,
435                        expr: Box::new(literal.clone()),
436                    })
437                    .is_literal_err()
438                {
439                    return true;
440                }
441            }
442        }
443
444        false
445    }
446
447    /// Determines if `self` is
448    /// `<expr> < <literal>` or
449    /// `<expr> > <literal>` or
450    /// `<literal> < <expr>` or
451    /// `<literal> > <expr>` or
452    /// `<expr> <= <literal>` or
453    /// `<expr> >= <literal>` or
454    /// `<literal> <= <expr>` or
455    /// `<literal> >= <expr>`.
456    pub fn any_expr_ineq_literal(&self) -> bool {
457        match self {
458            MirScalarExpr::CallBinary {
459                func:
460                    BinaryFunc::Lt(_) | BinaryFunc::Lte(_) | BinaryFunc::Gt(_) | BinaryFunc::Gte(_),
461                expr1,
462                expr2,
463            } => expr1.is_literal() || expr2.is_literal(),
464            _ => false,
465        }
466    }
467
468    /// Rewrites column indices with their value in `permutation`.
469    ///
470    /// This method is applicable even when `permutation` is not a
471    /// strict permutation, and it only needs to have entries for
472    /// each column referenced in `self`.
473    pub fn permute(&mut self, permutation: &[usize]) {
474        self.visit_columns(|c| *c = permutation[*c]);
475    }
476
477    /// Rewrites column indices with their value in `permutation`.
478    ///
479    /// This method is applicable even when `permutation` is not a
480    /// strict permutation, and it only needs to have entries for
481    /// each column referenced in `self`.
482    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
483        self.visit_columns(|c| *c = permutation[c]);
484    }
485
486    /// Visits each column reference and applies `action` to the column.
487    ///
488    /// Useful for remapping columns, or for collecting expression support.
489    pub fn visit_columns<F>(&mut self, mut action: F)
490    where
491        F: FnMut(&mut usize),
492    {
493        self.visit_pre_mut(|e| {
494            if let MirScalarExpr::Column(col, _) = e {
495                action(col);
496            }
497        });
498    }
499
500    pub fn support(&self) -> BTreeSet<usize> {
501        let mut support = BTreeSet::new();
502        self.support_into(&mut support);
503        support
504    }
505
506    pub fn support_into(&self, support: &mut BTreeSet<usize>) {
507        self.visit_pre(|e| {
508            if let MirScalarExpr::Column(i, _) = e {
509                support.insert(*i);
510            }
511        });
512    }
513
514    pub fn take(&mut self) -> Self {
515        mem::replace(self, MirScalarExpr::literal_null(ReprScalarType::String))
516    }
517
518    /// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
519    /// Otherwise, it returns None.
520    pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
521        if let MirScalarExpr::Literal(lit, _column_type) = self {
522            Some(lit.as_ref().map(|row| row.unpack_first()))
523        } else {
524            None
525        }
526    }
527
528    /// Flattens the two failure modes of `as_literal` into one layer of Option: returns the
529    /// literal's Datum only if the expression is a literal, and it's not a literal error.
530    pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
531        self.as_literal().map(|eval_err| eval_err.ok()).flatten()
532    }
533
534    pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
535        if let MirScalarExpr::Literal(lit, _column_type) = self {
536            Some(lit.clone())
537        } else {
538            None
539        }
540    }
541
542    pub fn as_literal_str(&self) -> Option<&str> {
543        match self.as_literal() {
544            Some(Ok(Datum::String(s))) => Some(s),
545            _ => None,
546        }
547    }
548
549    pub fn as_literal_int64(&self) -> Option<i64> {
550        match self.as_literal() {
551            Some(Ok(Datum::Int64(i))) => Some(i),
552            _ => None,
553        }
554    }
555
556    pub fn as_literal_err(&self) -> Option<&EvalError> {
557        self.as_literal().and_then(|lit| lit.err())
558    }
559
560    pub fn is_literal(&self) -> bool {
561        matches!(self, MirScalarExpr::Literal(_, _))
562    }
563
564    pub fn is_literal_true(&self) -> bool {
565        Some(Ok(Datum::True)) == self.as_literal()
566    }
567
568    pub fn is_literal_false(&self) -> bool {
569        Some(Ok(Datum::False)) == self.as_literal()
570    }
571
572    pub fn is_literal_null(&self) -> bool {
573        Some(Ok(Datum::Null)) == self.as_literal()
574    }
575
576    pub fn is_literal_ok(&self) -> bool {
577        matches!(self, MirScalarExpr::Literal(Ok(_), _typ))
578    }
579
580    pub fn is_literal_err(&self) -> bool {
581        matches!(self, MirScalarExpr::Literal(Err(_), _typ))
582    }
583
584    pub fn is_column(&self) -> bool {
585        matches!(self, MirScalarExpr::Column(_col, _name))
586    }
587
588    pub fn is_error_if_null(&self) -> bool {
589        matches!(
590            self,
591            Self::CallVariadic {
592                func: VariadicFunc::ErrorIfNull(_),
593                ..
594            }
595        )
596    }
597
598    /// If `self` expresses a temporal filter, normalize it to start with `mz_now()` and return
599    /// references.
600    ///
601    /// A temporal filter is an expression of the form `mz_now() <BINOP> <EXPR>`,
602    /// for a restricted set of `BINOP` and `EXPR` that do not themselves contain `mz_now()`.
603    /// Expressions may conform to this once their expressions are swapped.
604    ///
605    /// If the expression is not a temporal filter, it will be unchanged, and the reason for why
606    /// it's not a temporal filter is returned as a string.
607    pub fn as_mut_temporal_filter(&mut self) -> Result<(&BinaryFunc, &mut MirScalarExpr), String> {
608        if !self.contains_temporal() {
609            return Err("Does not involve mz_now()".to_string());
610        }
611        // Supported temporal predicates are exclusively binary operators.
612        if let MirScalarExpr::CallBinary { func, expr1, expr2 } = self {
613            // Attempt to put `LogicalTimestamp` in the first argument position.
614            if !expr1.contains_temporal()
615                && **expr2 == MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
616            {
617                std::mem::swap(expr1, expr2);
618                *func = match func {
619                    BinaryFunc::Eq(_) => func::Eq.into(),
620                    BinaryFunc::Lt(_) => func::Gt.into(),
621                    BinaryFunc::Lte(_) => func::Gte.into(),
622                    BinaryFunc::Gt(_) => func::Lt.into(),
623                    BinaryFunc::Gte(_) => func::Lte.into(),
624                    x => {
625                        return Err(format!("Unsupported binary temporal operation: {:?}", x));
626                    }
627                };
628            }
629
630            // Error if MLT is referenced in an unsupported position.
631            if expr2.contains_temporal()
632                || **expr1 != MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
633            {
634                return Err(format!(
635                    "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a mz_timestamp-castable expression. Expression found: {}",
636                    MirScalarExpr::CallBinary {
637                        func: func.clone(),
638                        expr1: expr1.clone(),
639                        expr2: expr2.clone()
640                    },
641                ));
642            }
643
644            Ok((&*func, expr2))
645        } else {
646            Err(format!(
647                "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a non-temporal expression of mz_timestamp-castable type. Expression found: {}",
648                self,
649            ))
650        }
651    }
652
653    #[deprecated = "Use `might_error` instead"]
654    pub fn contains_error_if_null(&self) -> bool {
655        let mut worklist = vec![self];
656        while let Some(expr) = worklist.pop() {
657            if expr.is_error_if_null() {
658                return true;
659            }
660            worklist.extend(expr.children());
661        }
662        false
663    }
664
665    pub fn contains_err(&self) -> bool {
666        let mut worklist = vec![self];
667        while let Some(expr) = worklist.pop() {
668            if expr.is_literal_err() {
669                return true;
670            }
671            worklist.extend(expr.children());
672        }
673        false
674    }
675
676    /// A very crude approximation for scalar expressions that might produce an
677    /// error.
678    ///
679    /// Currently, this is restricted only to expressions that either contain a
680    /// literal error or a [`VariadicFunc::ErrorIfNull`] call.
681    pub fn might_error(&self) -> bool {
682        let mut worklist = vec![self];
683        while let Some(expr) = worklist.pop() {
684            if expr.is_literal_err() || expr.is_error_if_null() {
685                return true;
686            }
687            worklist.extend(expr.children());
688        }
689        false
690    }
691
692    /// If self is a column, return the column index, otherwise `None`.
693    pub fn as_column(&self) -> Option<usize> {
694        if let MirScalarExpr::Column(c, _) = self {
695            Some(*c)
696        } else {
697            None
698        }
699    }
700
701    /// Reduces a complex expression where possible.
702    ///
703    /// This function uses nullability information present in `column_types`,
704    /// and the result may only continue to be a correct transformation as
705    /// long as this information continues to hold (nullability may not hold
706    /// as expressions migrate around).
707    ///
708    /// (If you'd like to not use nullability information here, then you can
709    /// tweak the nullabilities in `column_types` before passing it to this
710    /// function, see e.g. in `EquivalenceClasses::minimize`.)
711    ///
712    /// Also performs partial canonicalization on the expression.
713    ///
714    /// ```rust
715    /// use mz_expr::MirScalarExpr;
716    /// use mz_repr::{ReprColumnType, Datum, SqlScalarType};
717    ///
718    /// let expr_0 = MirScalarExpr::column(0);
719    /// let expr_t = MirScalarExpr::literal_true();
720    /// let expr_f = MirScalarExpr::literal_false();
721    ///
722    /// let mut test =
723    /// expr_t
724    ///     .clone()
725    ///     .and(expr_f.clone())
726    ///     .if_then_else(expr_0, expr_t.clone());
727    ///
728    /// let input_type = vec![ReprColumnType::from(&SqlScalarType::Int32.nullable(false))];
729    /// test.reduce(&input_type);
730    /// assert_eq!(test, expr_t);
731    /// ```
732    /// Reduce the expression to a simpler form.
733    pub fn reduce(&mut self, column_types: &[ReprColumnType]) {
734        let temp_storage = &RowArena::new();
735        let eval = |e: &MirScalarExpr| {
736            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type)
737        };
738
739        // Simplifications run in a loop until `self` no longer changes.
740        let mut old_self = MirScalarExpr::column(0);
741        while old_self != *self {
742            old_self = self.clone();
743            #[allow(deprecated)]
744            self.visit_mut_pre_post_nolimit(
745                &mut |e| {
746                    match e {
747                        MirScalarExpr::CallUnary { func, expr } => {
748                            if *func == UnaryFunc::IsNull(func::IsNull) {
749                                if !expr.typ(column_types).nullable {
750                                    *e = MirScalarExpr::literal_false();
751                                } else {
752                                    // Try to at least decompose IsNull into a disjunction
753                                    // of simpler IsNull subexpressions.
754                                    if let Some(expr) = expr.decompose_is_null() {
755                                        *e = expr
756                                    }
757                                }
758                            } else if *func == UnaryFunc::Not(func::Not) {
759                                // Push down not expressions
760                                match &mut **expr {
761                                    // Two negates cancel each other out.
762                                    MirScalarExpr::CallUnary {
763                                        expr: inner_expr,
764                                        func: UnaryFunc::Not(func::Not),
765                                    } => *e = inner_expr.take(),
766                                    // Transforms `NOT(a <op> b)` to `a negate(<op>) b`
767                                    // if a negation exists.
768                                    MirScalarExpr::CallBinary { expr1, expr2, func } => {
769                                        if let Some(negated_func) = func.negate() {
770                                            *e = MirScalarExpr::CallBinary {
771                                                expr1: Box::new(expr1.take()),
772                                                expr2: Box::new(expr2.take()),
773                                                func: negated_func,
774                                            }
775                                        }
776                                    }
777                                    MirScalarExpr::CallVariadic { .. } => {
778                                        e.demorgans();
779                                    }
780                                    _ => {}
781                                }
782                            }
783                        }
784                        _ => {}
785                    };
786                    None
787                },
788                &mut |e| match e {
789                    // Evaluate and pull up constants
790                    MirScalarExpr::Column(_, _)
791                    | MirScalarExpr::Literal(_, _)
792                    | MirScalarExpr::CallUnmaterializable(_) => (),
793                    MirScalarExpr::CallUnary { func, expr } => {
794                        if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
795                            *e = eval(e);
796                        } else if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
797                            if let MirScalarExpr::CallVariadic {
798                                func: VariadicFunc::RecordCreate(..),
799                                exprs,
800                            } = &mut **expr
801                            {
802                                *e = exprs.swap_remove(i);
803                            }
804                        }
805                    }
806                    MirScalarExpr::CallBinary { func, expr1, expr2 } => {
807                        if expr1.is_literal() && expr2.is_literal() {
808                            *e = eval(e);
809                        } else if (expr1.is_literal_null() || expr2.is_literal_null())
810                            && func.propagates_nulls()
811                        {
812                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
813                        } else if let Some(err) = expr1.as_literal_err() {
814                            *e = MirScalarExpr::literal(
815                                Err(err.clone()),
816                                e.typ(column_types).scalar_type,
817                            );
818                        } else if let Some(err) = expr2.as_literal_err() {
819                            *e = MirScalarExpr::literal(
820                                Err(err.clone()),
821                                e.typ(column_types).scalar_type,
822                            );
823                        } else if let BinaryFunc::IsLikeMatchCaseInsensitive(_) = func {
824                            if expr2.is_literal() {
825                                // We can at least precompile the regex.
826                                let pattern = expr2.as_literal_str().unwrap();
827                                *e = match like_pattern::compile(pattern, true) {
828                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
829                                        func::IsLikeMatch(matcher),
830                                    )),
831                                    Err(err) => MirScalarExpr::literal(
832                                        Err(err),
833                                        e.typ(column_types).scalar_type,
834                                    ),
835                                };
836                            }
837                        } else if let BinaryFunc::IsLikeMatchCaseSensitive(_) = func {
838                            if expr2.is_literal() {
839                                // We can at least precompile the regex.
840                                let pattern = expr2.as_literal_str().unwrap();
841                                *e = match like_pattern::compile(pattern, false) {
842                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
843                                        func::IsLikeMatch(matcher),
844                                    )),
845                                    Err(err) => MirScalarExpr::literal(
846                                        Err(err),
847                                        e.typ(column_types).scalar_type,
848                                    ),
849                                };
850                            }
851                        } else if matches!(
852                            func,
853                            BinaryFunc::IsRegexpMatchCaseSensitive(_)
854                                | BinaryFunc::IsRegexpMatchCaseInsensitive(_)
855                        ) {
856                            let case_insensitive =
857                                matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
858                            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
859                                *e = match Regex::new(
860                                    row.unpack_first().unwrap_str(),
861                                    case_insensitive,
862                                ) {
863                                    Ok(regex) => expr1.take().call_unary(UnaryFunc::IsRegexpMatch(
864                                        func::IsRegexpMatch(regex),
865                                    )),
866                                    Err(err) => MirScalarExpr::literal(
867                                        Err(err.into()),
868                                        e.typ(column_types).scalar_type,
869                                    ),
870                                };
871                            }
872                        } else if let BinaryFunc::ExtractInterval(_) = *func
873                            && expr1.is_literal()
874                        {
875                            let units = expr1.as_literal_str().unwrap();
876                            *e = match units.parse::<DateTimeUnits>() {
877                                Ok(units) => MirScalarExpr::CallUnary {
878                                    func: UnaryFunc::ExtractInterval(func::ExtractInterval(units)),
879                                    expr: Box::new(expr2.take()),
880                                },
881                                Err(_) => MirScalarExpr::literal(
882                                    Err(EvalError::UnknownUnits(units.into())),
883                                    e.typ(column_types).scalar_type,
884                                ),
885                            }
886                        } else if let BinaryFunc::ExtractTime(_) = *func
887                            && expr1.is_literal()
888                        {
889                            let units = expr1.as_literal_str().unwrap();
890                            *e = match units.parse::<DateTimeUnits>() {
891                                Ok(units) => MirScalarExpr::CallUnary {
892                                    func: UnaryFunc::ExtractTime(func::ExtractTime(units)),
893                                    expr: Box::new(expr2.take()),
894                                },
895                                Err(_) => MirScalarExpr::literal(
896                                    Err(EvalError::UnknownUnits(units.into())),
897                                    e.typ(column_types).scalar_type,
898                                ),
899                            }
900                        } else if let BinaryFunc::ExtractTimestamp(_) = *func
901                            && expr1.is_literal()
902                        {
903                            let units = expr1.as_literal_str().unwrap();
904                            *e = match units.parse::<DateTimeUnits>() {
905                                Ok(units) => MirScalarExpr::CallUnary {
906                                    func: UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(
907                                        units,
908                                    )),
909                                    expr: Box::new(expr2.take()),
910                                },
911                                Err(_) => MirScalarExpr::literal(
912                                    Err(EvalError::UnknownUnits(units.into())),
913                                    e.typ(column_types).scalar_type,
914                                ),
915                            }
916                        } else if let BinaryFunc::ExtractTimestampTz(_) = *func
917                            && expr1.is_literal()
918                        {
919                            let units = expr1.as_literal_str().unwrap();
920                            *e = match units.parse::<DateTimeUnits>() {
921                                Ok(units) => MirScalarExpr::CallUnary {
922                                    func: UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(
923                                        units,
924                                    )),
925                                    expr: Box::new(expr2.take()),
926                                },
927                                Err(_) => MirScalarExpr::literal(
928                                    Err(EvalError::UnknownUnits(units.into())),
929                                    e.typ(column_types).scalar_type,
930                                ),
931                            }
932                        } else if let BinaryFunc::ExtractDate(_) = *func
933                            && expr1.is_literal()
934                        {
935                            let units = expr1.as_literal_str().unwrap();
936                            *e = match units.parse::<DateTimeUnits>() {
937                                Ok(units) => MirScalarExpr::CallUnary {
938                                    func: UnaryFunc::ExtractDate(func::ExtractDate(units)),
939                                    expr: Box::new(expr2.take()),
940                                },
941                                Err(_) => MirScalarExpr::literal(
942                                    Err(EvalError::UnknownUnits(units.into())),
943                                    e.typ(column_types).scalar_type,
944                                ),
945                            }
946                        } else if let BinaryFunc::DatePartInterval(_) = *func
947                            && expr1.is_literal()
948                        {
949                            let units = expr1.as_literal_str().unwrap();
950                            *e = match units.parse::<DateTimeUnits>() {
951                                Ok(units) => MirScalarExpr::CallUnary {
952                                    func: UnaryFunc::DatePartInterval(func::DatePartInterval(
953                                        units,
954                                    )),
955                                    expr: Box::new(expr2.take()),
956                                },
957                                Err(_) => MirScalarExpr::literal(
958                                    Err(EvalError::UnknownUnits(units.into())),
959                                    e.typ(column_types).scalar_type,
960                                ),
961                            }
962                        } else if let BinaryFunc::DatePartTime(_) = *func
963                            && expr1.is_literal()
964                        {
965                            let units = expr1.as_literal_str().unwrap();
966                            *e = match units.parse::<DateTimeUnits>() {
967                                Ok(units) => MirScalarExpr::CallUnary {
968                                    func: UnaryFunc::DatePartTime(func::DatePartTime(units)),
969                                    expr: Box::new(expr2.take()),
970                                },
971                                Err(_) => MirScalarExpr::literal(
972                                    Err(EvalError::UnknownUnits(units.into())),
973                                    e.typ(column_types).scalar_type,
974                                ),
975                            }
976                        } else if let BinaryFunc::DatePartTimestamp(_) = *func
977                            && expr1.is_literal()
978                        {
979                            let units = expr1.as_literal_str().unwrap();
980                            *e = match units.parse::<DateTimeUnits>() {
981                                Ok(units) => MirScalarExpr::CallUnary {
982                                    func: UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(
983                                        units,
984                                    )),
985                                    expr: Box::new(expr2.take()),
986                                },
987                                Err(_) => MirScalarExpr::literal(
988                                    Err(EvalError::UnknownUnits(units.into())),
989                                    e.typ(column_types).scalar_type,
990                                ),
991                            }
992                        } else if let BinaryFunc::DatePartTimestampTz(_) = *func
993                            && expr1.is_literal()
994                        {
995                            let units = expr1.as_literal_str().unwrap();
996                            *e = match units.parse::<DateTimeUnits>() {
997                                Ok(units) => MirScalarExpr::CallUnary {
998                                    func: UnaryFunc::DatePartTimestampTz(
999                                        func::DatePartTimestampTz(units),
1000                                    ),
1001                                    expr: Box::new(expr2.take()),
1002                                },
1003                                Err(_) => MirScalarExpr::literal(
1004                                    Err(EvalError::UnknownUnits(units.into())),
1005                                    e.typ(column_types).scalar_type,
1006                                ),
1007                            }
1008                        } else if let BinaryFunc::DateTruncTimestamp(_) = *func
1009                            && expr1.is_literal()
1010                        {
1011                            let units = expr1.as_literal_str().unwrap();
1012                            *e = match units.parse::<DateTimeUnits>() {
1013                                Ok(units) => MirScalarExpr::CallUnary {
1014                                    func: UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(
1015                                        units,
1016                                    )),
1017                                    expr: Box::new(expr2.take()),
1018                                },
1019                                Err(_) => MirScalarExpr::literal(
1020                                    Err(EvalError::UnknownUnits(units.into())),
1021                                    e.typ(column_types).scalar_type,
1022                                ),
1023                            }
1024                        } else if let BinaryFunc::DateTruncTimestampTz(_) = *func
1025                            && expr1.is_literal()
1026                        {
1027                            let units = expr1.as_literal_str().unwrap();
1028                            *e = match units.parse::<DateTimeUnits>() {
1029                                Ok(units) => MirScalarExpr::CallUnary {
1030                                    func: UnaryFunc::DateTruncTimestampTz(
1031                                        func::DateTruncTimestampTz(units),
1032                                    ),
1033                                    expr: Box::new(expr2.take()),
1034                                },
1035                                Err(_) => MirScalarExpr::literal(
1036                                    Err(EvalError::UnknownUnits(units.into())),
1037                                    e.typ(column_types).scalar_type,
1038                                ),
1039                            }
1040                        } else if matches!(func, BinaryFunc::TimezoneTimestampBinary(_))
1041                            && expr1.is_literal()
1042                        {
1043                            // If the timezone argument is a literal, and we're applying the function on many rows at the same
1044                            // time we really don't want to parse it again and again, so we parse it once and embed it into the
1045                            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
1046                            let tz = expr1.as_literal_str().unwrap();
1047                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1048                                Ok(tz) => MirScalarExpr::CallUnary {
1049                                    func: UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz)),
1050                                    expr: Box::new(expr2.take()),
1051                                },
1052                                Err(err) => MirScalarExpr::literal(
1053                                    Err(err),
1054                                    e.typ(column_types).scalar_type,
1055                                ),
1056                            }
1057                        } else if matches!(func, BinaryFunc::TimezoneTimestampTzBinary(_))
1058                            && expr1.is_literal()
1059                        {
1060                            let tz = expr1.as_literal_str().unwrap();
1061                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1062                                Ok(tz) => MirScalarExpr::CallUnary {
1063                                    func: UnaryFunc::TimezoneTimestampTz(
1064                                        func::TimezoneTimestampTz(tz),
1065                                    ),
1066                                    expr: Box::new(expr2.take()),
1067                                },
1068                                Err(err) => MirScalarExpr::literal(
1069                                    Err(err),
1070                                    e.typ(column_types).scalar_type,
1071                                ),
1072                            }
1073                        } else if let BinaryFunc::ToCharTimestamp(_) = *func
1074                            && expr2.is_literal()
1075                        {
1076                            let format_str = expr2.as_literal_str().unwrap();
1077                            *e = MirScalarExpr::CallUnary {
1078                                func: UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
1079                                    format_string: format_str.to_string(),
1080                                    format: DateTimeFormat::compile(format_str),
1081                                }),
1082                                expr: Box::new(expr1.take()),
1083                            };
1084                        } else if let BinaryFunc::ToCharTimestampTz(_) = *func
1085                            && expr2.is_literal()
1086                        {
1087                            let format_str = expr2.as_literal_str().unwrap();
1088                            *e = MirScalarExpr::CallUnary {
1089                                func: UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
1090                                    format_string: format_str.to_string(),
1091                                    format: DateTimeFormat::compile(format_str),
1092                                }),
1093                                expr: Box::new(expr1.take()),
1094                            };
1095                        } else if matches!(*func, BinaryFunc::Eq(_) | BinaryFunc::NotEq(_))
1096                            && expr2 < expr1
1097                        {
1098                            // Canonically order elements so that deduplication works better.
1099                            // Also, the below `Literal([c1, c2]) = record_create(e1, e2)` matching
1100                            // relies on this canonical ordering.
1101                            mem::swap(expr1, expr2);
1102                        } else if let (
1103                            BinaryFunc::Eq(_),
1104                            MirScalarExpr::Literal(
1105                                Ok(lit_row),
1106                                ReprColumnType {
1107                                    scalar_type:
1108                                        ReprScalarType::Record {
1109                                            fields: field_types,
1110                                            ..
1111                                        },
1112                                    ..
1113                                },
1114                            ),
1115                            MirScalarExpr::CallVariadic {
1116                                func: VariadicFunc::RecordCreate(..),
1117                                exprs: rec_create_args,
1118                            },
1119                        ) = (&*func, &**expr1, &**expr2)
1120                        {
1121                            // Literal([c1, c2]) = record_create(e1, e2)
1122                            //  -->
1123                            // c1 = e1 AND c2 = e2
1124                            //
1125                            // (Records are represented as lists.)
1126                            //
1127                            // `MapFilterProject::literal_constraints` relies on this transform,
1128                            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
1129                            match lit_row.unpack_first() {
1130                                Datum::List(datum_list) => {
1131                                    *e = MirScalarExpr::call_variadic(
1132                                        And,
1133                                        datum_list
1134                                            .iter()
1135                                            .zip_eq(field_types)
1136                                            .zip_eq(rec_create_args)
1137                                            .map(|((d, typ), a)| {
1138                                                MirScalarExpr::literal_ok(
1139                                                    d,
1140                                                    typ.scalar_type.clone(),
1141                                                )
1142                                                .call_binary(a.clone(), func::Eq)
1143                                            })
1144                                            .collect(),
1145                                    );
1146                                }
1147                                _ => {}
1148                            }
1149                        } else if let (
1150                            BinaryFunc::Eq(_),
1151                            MirScalarExpr::CallVariadic {
1152                                func: VariadicFunc::RecordCreate(..),
1153                                exprs: rec_create_args1,
1154                            },
1155                            MirScalarExpr::CallVariadic {
1156                                func: VariadicFunc::RecordCreate(..),
1157                                exprs: rec_create_args2,
1158                            },
1159                        ) = (&*func, &**expr1, &**expr2)
1160                        {
1161                            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
1162                            //  -->
1163                            // a1 = b1 AND a2 = b2 AND ...
1164                            //
1165                            // This is similar to the previous reduction, but this one kicks in also
1166                            // when only some (or none) of the record fields are literals. This
1167                            // enables the discovery of literal constraints for those fields.
1168                            //
1169                            // Note that there is a similar decomposition in
1170                            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
1171                            // pipeline than the compilation of IN lists to `record_create`.
1172                            *e = MirScalarExpr::call_variadic(
1173                                And,
1174                                rec_create_args1
1175                                    .into_iter()
1176                                    .zip_eq(rec_create_args2)
1177                                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
1178                                    .collect(),
1179                            )
1180                        }
1181                    }
1182                    MirScalarExpr::CallVariadic { .. } => {
1183                        e.flatten_associative();
1184                        let (func, exprs) = match e {
1185                            MirScalarExpr::CallVariadic { func, exprs } => (func, exprs),
1186                            _ => unreachable!("`flatten_associative` shouldn't change node type"),
1187                        };
1188                        if *func == Coalesce.into() {
1189                            // If all inputs are null, output is null. This check must
1190                            // be done before `exprs.retain...` because `e.typ` requires
1191                            // > 0 `exprs` remain.
1192                            if exprs.iter().all(|expr| expr.is_literal_null()) {
1193                                *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1194                                return;
1195                            }
1196
1197                            // Remove any null values if not all values are null.
1198                            exprs.retain(|e| !e.is_literal_null());
1199
1200                            // Find the first argument that is a literal or non-nullable
1201                            // column. All arguments after it get ignored, so throw them
1202                            // away. This intentionally throws away errors that can
1203                            // never happen.
1204                            if let Some(i) = exprs
1205                                .iter()
1206                                .position(|e| e.is_literal() || !e.typ(column_types).nullable)
1207                            {
1208                                exprs.truncate(i + 1);
1209                            }
1210
1211                            // Deduplicate arguments in cases like `coalesce(#0, #0)`.
1212                            let mut prior_exprs = BTreeSet::new();
1213                            exprs.retain(|e| prior_exprs.insert(e.clone()));
1214
1215                            if exprs.len() == 1 {
1216                                // Only one argument, so the coalesce is a no-op.
1217                                *e = exprs[0].take();
1218                            }
1219                        } else if exprs.iter().all(|e| e.is_literal()) {
1220                            *e = eval(e);
1221                        } else if func.propagates_nulls()
1222                            && exprs.iter().any(|e| e.is_literal_null())
1223                        {
1224                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1225                        } else if let Some(err) = exprs.iter().find_map(|e| e.as_literal_err()) {
1226                            *e = MirScalarExpr::literal(
1227                                Err(err.clone()),
1228                                e.typ(column_types).scalar_type,
1229                            );
1230                        } else if *func == RegexpMatch.into()
1231                            && exprs[1].is_literal()
1232                            && exprs.get(2).map_or(true, |e| e.is_literal())
1233                        {
1234                            let needle = exprs[1].as_literal_str().unwrap();
1235                            let flags = match exprs.len() {
1236                                3 => exprs[2].as_literal_str().unwrap(),
1237                                _ => "",
1238                            };
1239                            *e = match func::build_regex(needle, flags) {
1240                                Ok(regex) => mem::take(exprs)
1241                                    .into_first()
1242                                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
1243                                Err(err) => MirScalarExpr::literal(
1244                                    Err(err),
1245                                    e.typ(column_types).scalar_type,
1246                                ),
1247                            };
1248                        } else if *func == RegexpReplace.into()
1249                            && exprs[1].is_literal()
1250                            && exprs.get(3).map_or(true, |e| e.is_literal())
1251                        {
1252                            let pattern = exprs[1].as_literal_str().unwrap();
1253                            let flags = exprs
1254                                .get(3)
1255                                .map_or("", |expr| expr.as_literal_str().unwrap());
1256                            let (limit, flags) = regexp_replace_parse_flags(flags);
1257
1258                            // The behavior of `regexp_replace` is that if the data is `NULL`, the
1259                            // function returns `NULL`, independently of whether the pattern or
1260                            // flags are correct. We need to check for this case and introduce an
1261                            // if-then-else on the error path to only surface the error if the first
1262                            // input is non-NULL.
1263                            *e = match func::build_regex(pattern, &flags) {
1264                                Ok(regex) => {
1265                                    let mut exprs = mem::take(exprs);
1266                                    let replacement = exprs.swap_remove(2);
1267                                    let source = exprs.swap_remove(0);
1268                                    source.call_binary(
1269                                        replacement,
1270                                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
1271                                    )
1272                                }
1273                                Err(err) => {
1274                                    let mut exprs = mem::take(exprs);
1275                                    let source = exprs.swap_remove(0);
1276                                    let scalar_type = e.typ(column_types).scalar_type;
1277                                    // We need to return `NULL` on `NULL` input, and error otherwise.
1278                                    source.call_is_null().if_then_else(
1279                                        MirScalarExpr::literal_null(scalar_type.clone()),
1280                                        MirScalarExpr::literal(Err(err), scalar_type),
1281                                    )
1282                                }
1283                            };
1284                        } else if *func == RegexpSplitToArray.into()
1285                            && exprs[1].is_literal()
1286                            && exprs.get(2).map_or(true, |e| e.is_literal())
1287                        {
1288                            let needle = exprs[1].as_literal_str().unwrap();
1289                            let flags = match exprs.len() {
1290                                3 => exprs[2].as_literal_str().unwrap(),
1291                                _ => "",
1292                            };
1293                            *e = match func::build_regex(needle, flags) {
1294                                Ok(regex) => mem::take(exprs).into_first().call_unary(
1295                                    UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(regex)),
1296                                ),
1297                                Err(err) => MirScalarExpr::literal(
1298                                    Err(err),
1299                                    e.typ(column_types).scalar_type,
1300                                ),
1301                            };
1302                        } else if *func == ListIndex.into() && is_list_create_call(&exprs[0]) {
1303                            // We are looking for ListIndex(ListCreate, literal), and eliminate
1304                            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
1305                            let ind_exprs = exprs.split_off(1);
1306                            let top_list_create = exprs.swap_remove(0);
1307                            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
1308                        } else if *func == Or.into() || *func == And.into() {
1309                            // Note: It's important that we have called `flatten_associative` above.
1310                            e.undistribute_and_or();
1311                            e.reduce_and_canonicalize_and_or();
1312                        } else if let VariadicFunc::TimezoneTimeVariadic(_) = func {
1313                            if exprs[0].is_literal() && exprs[2].is_literal_ok() {
1314                                let tz = exprs[0].as_literal_str().unwrap();
1315                                *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1316                                    Ok(tz) => MirScalarExpr::CallUnary {
1317                                        func: UnaryFunc::TimezoneTime(func::TimezoneTime {
1318                                            tz,
1319                                            wall_time: exprs[2]
1320                                                .as_literal()
1321                                                .unwrap()
1322                                                .unwrap()
1323                                                .unwrap_timestamptz()
1324                                                .naive_utc(),
1325                                        }),
1326                                        expr: Box::new(exprs[1].take()),
1327                                    },
1328                                    Err(err) => MirScalarExpr::literal(
1329                                        Err(err),
1330                                        e.typ(column_types).scalar_type,
1331                                    ),
1332                                }
1333                            }
1334                        }
1335                    }
1336                    MirScalarExpr::If { cond, then, els } => {
1337                        if let Some(literal) = cond.as_literal() {
1338                            match literal {
1339                                Ok(Datum::True) => *e = then.take(),
1340                                Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
1341                                Err(err) => {
1342                                    *e = MirScalarExpr::Literal(
1343                                        Err(err.clone()),
1344                                        then.typ(column_types)
1345                                            .union(&els.typ(column_types))
1346                                            .unwrap(),
1347                                    )
1348                                }
1349                                _ => unreachable!(),
1350                            }
1351                        } else if then == els {
1352                            *e = then.take();
1353                        } else if then.is_literal_ok()
1354                            && els.is_literal_ok()
1355                            && then.typ(column_types).scalar_type == ReprScalarType::Bool
1356                            && els.typ(column_types).scalar_type == ReprScalarType::Bool
1357                        {
1358                            match (then.as_literal(), els.as_literal()) {
1359                                // Note: NULLs from the condition should not be propagated to the result
1360                                // of the expression.
1361                                (Some(Ok(Datum::True)), _) => {
1362                                    // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
1363                                    // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
1364                                    *e = cond
1365                                        .clone()
1366                                        .call_is_null()
1367                                        .not()
1368                                        .and(cond.take())
1369                                        .or(els.take());
1370                                }
1371                                (Some(Ok(Datum::False)), _) => {
1372                                    // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
1373                                    // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
1374                                    *e = cond
1375                                        .clone()
1376                                        .not()
1377                                        .or(cond.take().call_is_null())
1378                                        .and(els.take());
1379                                }
1380                                (_, Some(Ok(Datum::True))) => {
1381                                    // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
1382                                    // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
1383                                    *e = cond
1384                                        .clone()
1385                                        .not()
1386                                        .or(cond.take().call_is_null())
1387                                        .or(then.take());
1388                                }
1389                                (_, Some(Ok(Datum::False))) => {
1390                                    // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
1391                                    // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
1392                                    *e = cond
1393                                        .clone()
1394                                        .call_is_null()
1395                                        .not()
1396                                        .and(cond.take())
1397                                        .and(then.take());
1398                                }
1399                                _ => {}
1400                            }
1401                        } else {
1402                            // Equivalent expression structure would allow us to push the `If` into the expression.
1403                            // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
1404                            //
1405                            // We have to also make sure that the expressions that will end up in
1406                            // the two `If` branches have unionable types. Otherwise, the `If` could
1407                            // not be typed by `typ`. An example where this could cause an issue is
1408                            // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
1409                            // range of input types. (In theory, we could still do the optimization
1410                            // in this case by inserting appropriate casts, but this corner case is
1411                            // not worth the complication for now.)
1412                            // See https://github.com/MaterializeInc/database-issues/issues/9182
1413                            match (&mut **then, &mut **els) {
1414                                (
1415                                    MirScalarExpr::CallUnary { func: f1, expr: e1 },
1416                                    MirScalarExpr::CallUnary { func: f2, expr: e2 },
1417                                ) if f1 == f2 && e1.typ(column_types) == e2.typ(column_types) => {
1418                                    *e = cond
1419                                        .take()
1420                                        .if_then_else(e1.take(), e2.take())
1421                                        .call_unary(f1.clone());
1422                                }
1423                                (
1424                                    MirScalarExpr::CallBinary {
1425                                        func: f1,
1426                                        expr1: e1a,
1427                                        expr2: e2a,
1428                                    },
1429                                    MirScalarExpr::CallBinary {
1430                                        func: f2,
1431                                        expr1: e1b,
1432                                        expr2: e2b,
1433                                    },
1434                                ) if f1 == f2
1435                                    && e1a == e1b
1436                                    && e2a.typ(column_types) == e2b.typ(column_types) =>
1437                                {
1438                                    *e = e1a.take().call_binary(
1439                                        cond.take().if_then_else(e2a.take(), e2b.take()),
1440                                        f1.clone(),
1441                                    );
1442                                }
1443                                (
1444                                    MirScalarExpr::CallBinary {
1445                                        func: f1,
1446                                        expr1: e1a,
1447                                        expr2: e2a,
1448                                    },
1449                                    MirScalarExpr::CallBinary {
1450                                        func: f2,
1451                                        expr1: e1b,
1452                                        expr2: e2b,
1453                                    },
1454                                ) if f1 == f2
1455                                    && e2a == e2b
1456                                    && e1a.typ(column_types) == e1b.typ(column_types) =>
1457                                {
1458                                    *e = cond
1459                                        .take()
1460                                        .if_then_else(e1a.take(), e1b.take())
1461                                        .call_binary(e2a.take(), f1.clone());
1462                                }
1463                                _ => {}
1464                            }
1465                        }
1466                    }
1467                },
1468            );
1469        }
1470
1471        /* #region `reduce_list_create_list_index_literal` and helper functions */
1472
1473        fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
1474            if let MirScalarExpr::CallVariadic {
1475                func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
1476                ..
1477            } = list_create
1478            {
1479                ReprScalarType::from(typ)
1480            } else {
1481                unreachable!()
1482            }
1483        }
1484
1485        fn is_list_create_call(expr: &MirScalarExpr) -> bool {
1486            matches!(
1487                expr,
1488                MirScalarExpr::CallVariadic {
1489                    func: VariadicFunc::ListCreate(..),
1490                    ..
1491                }
1492            )
1493        }
1494
1495        /// Partial-evaluates a list indexing with a literal directly after a list creation.
1496        ///
1497        /// Multi-dimensional lists are handled by a single call to this function, with multiple
1498        /// elements in index_exprs (of which not all need to be literals), and nested ListCreates
1499        /// in list_create_to_reduce.
1500        ///
1501        /// # Examples
1502        ///
1503        /// `LIST[f1,f2][2]` --> `f2`.
1504        ///
1505        /// A multi-dimensional list, with only some of the indexes being literals:
1506        /// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
1507        ///
1508        /// See more examples in list.slt.
1509        fn reduce_list_create_list_index_literal(
1510            mut list_create_to_reduce: MirScalarExpr,
1511            mut index_exprs: Vec<MirScalarExpr>,
1512        ) -> MirScalarExpr {
1513            // We iterate over the index_exprs and remove literals, but keep non-literals.
1514            // When we encounter a non-literal, we need to dig into the nested ListCreates:
1515            // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
1516            // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
1517            // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
1518            // that are at the current level (except those that disappeared due to
1519            // literals at earlier levels), index into them with the literal, and change each
1520            // element in `list_create_mut_refs` to the result.
1521            // We also record mut refs to all the earlier `element_type` references that we have
1522            // seen in ListCreate calls, because when we process a literal index, we need to remove
1523            // one layer of list type from all these earlier ListCreate `element_type`s.
1524            let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
1525            let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
1526            let mut i = 0;
1527            while i < index_exprs.len()
1528                && list_create_mut_refs
1529                    .iter()
1530                    .all(|lc| is_list_create_call(lc))
1531            {
1532                if index_exprs[i].is_literal_ok() {
1533                    // We can remove this index.
1534                    let removed_index = index_exprs.remove(i);
1535                    let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
1536                        Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
1537                        _ => unreachable!(), // always an Int64, see plan_index_list
1538                    };
1539                    // For each list_create referenced by list_create_mut_refs, substitute it by its
1540                    // `index`th argument (or null).
1541                    for list_create in &mut list_create_mut_refs {
1542                        let list_create_args = match list_create {
1543                            MirScalarExpr::CallVariadic {
1544                                func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
1545                                exprs,
1546                            } => exprs,
1547                            _ => unreachable!(), // func cannot be anything else than a ListCreate
1548                        };
1549                        // ListIndex gives null on an out-of-bounds index
1550                        if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap()
1551                        {
1552                            let index: usize = index_i64.try_into().unwrap();
1553                            **list_create = list_create_args.swap_remove(index);
1554                        } else {
1555                            let typ = list_create_type(list_create);
1556                            **list_create = MirScalarExpr::literal_null(typ);
1557                        }
1558                    }
1559                    // Peel one layer off of each of the earlier element types.
1560                    for t in earlier_list_create_types.iter_mut() {
1561                        if let SqlScalarType::List {
1562                            element_type,
1563                            custom_id: _,
1564                        } = t
1565                        {
1566                            **t = *element_type.clone();
1567                            // These are not the same types anymore, so remove custom_ids all the
1568                            // way down.
1569                            let mut u = &mut **t;
1570                            while let SqlScalarType::List {
1571                                element_type,
1572                                custom_id,
1573                            } = u
1574                            {
1575                                *custom_id = None;
1576                                u = &mut **element_type;
1577                            }
1578                        } else {
1579                            unreachable!("already matched below");
1580                        }
1581                    }
1582                } else {
1583                    // We can't remove this index, so we can't reduce any of the ListCreates at this
1584                    // level. So we change list_create_mut_refs to refer to all the arguments of all
1585                    // the ListCreates currently referenced by list_create_mut_refs.
1586                    list_create_mut_refs = list_create_mut_refs
1587                        .into_iter()
1588                        .flat_map(|list_create| match list_create {
1589                            MirScalarExpr::CallVariadic {
1590                                func: VariadicFunc::ListCreate(ListCreate { elem_type }),
1591                                exprs: list_create_args,
1592                            } => {
1593                                earlier_list_create_types.push(elem_type);
1594                                list_create_args
1595                            }
1596                            // func cannot be anything else than a ListCreate
1597                            _ => unreachable!(),
1598                        })
1599                        .collect();
1600                    i += 1; // next index_expr
1601                }
1602            }
1603            // If all list indexes have been evaluated, return the reduced expression.
1604            // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
1605            if index_exprs.is_empty() {
1606                assert_eq!(list_create_mut_refs.len(), 1);
1607                list_create_to_reduce
1608            } else {
1609                MirScalarExpr::call_variadic(
1610                    ListIndex,
1611                    std::iter::once(list_create_to_reduce)
1612                        .chain(index_exprs)
1613                        .collect(),
1614                )
1615            }
1616        }
1617
1618        /* #endregion */
1619    }
1620
1621    /// Decompose an IsNull expression into a disjunction of
1622    /// simpler expressions.
1623    ///
1624    /// Assumes that `self` is the expression inside of an IsNull.
1625    /// Returns `Some(expressions)` if the outer IsNull is to be
1626    /// replaced by some other expression. Note: if it returns
1627    /// None, it might still have mutated *self.
1628    fn decompose_is_null(&mut self) -> Option<MirScalarExpr> {
1629        // TODO: allow simplification of unmaterializable functions
1630
1631        match self {
1632            MirScalarExpr::CallUnary {
1633                func,
1634                expr: inner_expr,
1635            } => {
1636                if !func.introduces_nulls() {
1637                    if func.propagates_nulls() {
1638                        *self = inner_expr.take();
1639                        return self.decompose_is_null();
1640                    } else {
1641                        // We can simplify to `false`, because the function simply can't produce
1642                        // nulls at all. This is because
1643                        // - !propagates_nulls means that the input type of the Rust function is not
1644                        //   nullable, so the automatic null propagation won't kick in;
1645                        // - !introduces_nulls means that the output type of the Rust function is
1646                        //   not nullable, so the Rust function can't produce a null manually either.
1647                        //
1648                        // Note that we can't do this same optimization for binary and variadic
1649                        // functions. This is because for binary and variadic functions the value of
1650                        // propagates_nulls and introduces_nulls is not derived solely from the
1651                        // input/output type nullabilities, but instead depends on what the Rust
1652                        // function does. For example, list concatenation neither introduces nor
1653                        // propagates nulls, but it can produce a null:
1654                        // - It does not introduce nulls, because if both input lists are not null,
1655                        //   then it will produce a list.
1656                        // - It does not propagate nulls, because giving a null as just one of the
1657                        //   arguments returns the other argument instead of null.
1658                        // - It does produce a null if both arguments are null.
1659                        return Some(MirScalarExpr::literal_false());
1660                    }
1661                }
1662            }
1663            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1664                // (<expr1> <op> <expr2>) IS NULL can often be simplified to
1665                // (<expr1> IS NULL) OR (<expr2> IS NULL).
1666                if func.propagates_nulls() && !func.introduces_nulls() {
1667                    let expr1 = expr1.take().call_is_null();
1668                    let expr2 = expr2.take().call_is_null();
1669                    return Some(expr1.or(expr2));
1670                }
1671            }
1672            MirScalarExpr::CallVariadic { func, exprs } => {
1673                if func.propagates_nulls() && !func.introduces_nulls() {
1674                    let exprs = exprs.into_iter().map(|e| e.take().call_is_null()).collect();
1675                    return Some(MirScalarExpr::call_variadic(Or, exprs));
1676                }
1677            }
1678            _ => {}
1679        }
1680
1681        None
1682    }
1683
1684    /// Flattens a chain of calls to associative variadic functions
1685    /// (For example: ORs or ANDs)
1686    pub fn flatten_associative(&mut self) {
1687        match self {
1688            MirScalarExpr::CallVariadic {
1689                exprs: outer_operands,
1690                func: outer_func,
1691            } if outer_func.is_associative() => {
1692                *outer_operands = outer_operands
1693                    .into_iter()
1694                    .flat_map(|o| {
1695                        if let MirScalarExpr::CallVariadic {
1696                            exprs: inner_operands,
1697                            func: inner_func,
1698                        } = o
1699                        {
1700                            if *inner_func == *outer_func {
1701                                mem::take(inner_operands)
1702                            } else {
1703                                vec![o.take()]
1704                            }
1705                        } else {
1706                            vec![o.take()]
1707                        }
1708                    })
1709                    .collect();
1710            }
1711            _ => {}
1712        }
1713    }
1714
1715    /* #region AND/OR canonicalization and transformations  */
1716
1717    /// Canonicalizes AND/OR, and does some straightforward simplifications
1718    fn reduce_and_canonicalize_and_or(&mut self) {
1719        // We do this until fixed point, because after undistribute_and_or calls us, it relies on
1720        // the property that self is not an 1-arg AND/OR. Just one application of our loop body
1721        // can't ensure this, because the application itself might create a 1-arg AND/OR.
1722        let mut old_self = MirScalarExpr::column(0);
1723        while old_self != *self {
1724            old_self = self.clone();
1725            match self {
1726                MirScalarExpr::CallVariadic {
1727                    func: func @ (VariadicFunc::And(_) | VariadicFunc::Or(_)),
1728                    exprs,
1729                } => {
1730                    // Canonically order elements so that various deduplications work better,
1731                    // e.g., in undistribute_and_or.
1732                    // Also, extract_equal_or_both_null_inner depends on the args being sorted.
1733                    exprs.sort();
1734
1735                    // x AND/OR x --> x
1736                    exprs.dedup(); // this also needs the above sorting
1737
1738                    if exprs.len() == 1 {
1739                        // AND/OR of 1 argument evaluates to that argument
1740                        *self = exprs.swap_remove(0);
1741                    } else if exprs.len() == 0 {
1742                        // AND/OR of 0 arguments evaluates to true/false
1743                        *self = func.unit_of_and_or();
1744                    } else if exprs.iter().any(|e| *e == func.zero_of_and_or()) {
1745                        // short-circuiting
1746                        *self = func.zero_of_and_or();
1747                    } else {
1748                        // a AND true --> a
1749                        // a OR false --> a
1750                        exprs.retain(|e| *e != func.unit_of_and_or());
1751                    }
1752                }
1753                _ => {}
1754            }
1755        }
1756    }
1757
1758    /// Transforms !(a && b) into !a || !b, and !(a || b) into !a && !b
1759    fn demorgans(&mut self) {
1760        if let MirScalarExpr::CallUnary {
1761            expr: inner,
1762            func: UnaryFunc::Not(func::Not),
1763        } = self
1764        {
1765            inner.flatten_associative();
1766            match &mut **inner {
1767                MirScalarExpr::CallVariadic {
1768                    func: inner_func @ (VariadicFunc::And(_) | VariadicFunc::Or(_)),
1769                    exprs,
1770                } => {
1771                    *inner_func = inner_func.switch_and_or();
1772                    *exprs = exprs.into_iter().map(|e| e.take().not()).collect();
1773                    *self = (*inner).take(); // Removes the outer not
1774                }
1775                _ => {}
1776            }
1777        }
1778    }
1779
1780    /// AND/OR undistribution (factoring out) to apply at each `MirScalarExpr`.
1781    ///
1782    /// This method attempts to apply one of the [distribution laws][distributivity]
    /// (in the direction opposite to their name):
1784    /// ```text
1785    /// (a && b) || (a && c) --> a && (b || c)  // Undistribute-OR
1786    /// (a || b) && (a || c) --> a || (b && c)  // Undistribute-AND
1787    /// ```
1788    /// or one of their corresponding two [absorption law][absorption] special
1789    /// cases:
1790    /// ```text
1791    /// a || (a && c)  -->  a  // Absorb-OR
1792    /// a && (a || c)  -->  a  // Absorb-AND
1793    /// ```
1794    ///
1795    /// The method also works with more than 2 arguments at the top, e.g.
1796    /// ```text
1797    /// (a && b) || (a && c) || (a && d)  -->  a && (b || c || d)
1798    /// ```
1799    /// It can also factor out only a subset of the top arguments, e.g.
1800    /// ```text
1801    /// (a && b) || (a && c) || (d && e)  -->  (a && (b || c)) || (d && e)
1802    /// ```
1803    ///
1804    /// Note that sometimes there are two overlapping possibilities to factor
1805    /// out from, e.g.
1806    /// ```text
1807    /// (a && b) || (a && c) || (d && c)
1808    /// ```
    /// Here we can factor out `a` from the 1st and 2nd terms, or we can
    /// factor out `c` from the 2nd and 3rd terms. One of these might lead to
1811    /// more/better undistribution opportunities later, but we just pick one
1812    /// locally, because recursively trying out all of them would lead to
1813    /// exponential run time.
1814    ///
1815    /// The local heuristic is that we prefer a candidate that leads to an
1816    /// absorption, or if there is no such one then we simply pick the first. In
1817    /// case of multiple absorption candidates, it doesn't matter which one we
    /// pick, because applying an absorption cannot adversely affect the
1819    /// possibility of applying other absorptions.
1820    ///
1821    /// # Assumption
1822    ///
1823    /// Assumes that nested chains of AND/OR applications are flattened (this
1824    /// can be enforced with [`Self::flatten_associative`]).
1825    ///
1826    /// # Examples
1827    ///
1828    /// Absorb-OR:
1829    /// ```text
1830    /// a || (a && c) || (a && d)
1831    /// -->
1832    /// a && (true || c || d)
1833    /// -->
1834    /// a && true
1835    /// -->
1836    /// a
1837    /// ```
1838    /// Here only the first step is performed by this method. The rest is done
1839    /// by [`Self::reduce_and_canonicalize_and_or`] called after us in
1840    /// `reduce()`.
1841    ///
1842    /// [distributivity]: https://en.wikipedia.org/wiki/Distributive_property
1843    /// [absorption]: https://en.wikipedia.org/wiki/Absorption_law
1844    fn undistribute_and_or(&mut self) {
1845        // It wouldn't be strictly necessary to wrap this fn in this loop, because `reduce()` calls
1846        // us in a loop anyway. However, `reduce()` tries to do many other things, so the loop here
1847        // improves performance when there are several undistributions to apply in sequence, which
1848        // can occur in `CanonicalizeMfp` when undoing the DNF.
1849        let mut old_self = MirScalarExpr::column(0);
1850        while old_self != *self {
1851            old_self = self.clone();
1852            self.reduce_and_canonicalize_and_or(); // We don't want to deal with 1-arg AND/OR at the top
1853            if let MirScalarExpr::CallVariadic {
1854                exprs: outer_operands,
1855                func: outer_func @ (VariadicFunc::Or(_) | VariadicFunc::And(_)),
1856            } = self
1857            {
1858                let inner_func = outer_func.switch_and_or();
1859
1860                // Make sure that each outer operand is a call to inner_func, by wrapping in a 1-arg
1861                // call if necessary.
1862                outer_operands.iter_mut().for_each(|o| {
1863                    if !matches!(o, MirScalarExpr::CallVariadic {func: f, ..} if *f == inner_func) {
1864                        *o = MirScalarExpr::CallVariadic {
1865                            func: inner_func.clone(),
1866                            exprs: vec![o.take()],
1867                        };
1868                    }
1869                });
1870
1871                let mut inner_operands_refs: Vec<&mut Vec<MirScalarExpr>> = outer_operands
1872                    .iter_mut()
1873                    .map(|o| match o {
1874                        MirScalarExpr::CallVariadic { func: f, exprs } if *f == inner_func => exprs,
1875                        _ => unreachable!(), // the wrapping made sure that we'll get a match
1876                    })
1877                    .collect();
1878
1879                // Find inner operands to undistribute, i.e., which are in _all_ of the outer operands.
1880                let mut intersection = inner_operands_refs
1881                    .iter()
1882                    .map(|v| (*v).clone())
1883                    .reduce(|ops1, ops2| ops1.into_iter().filter(|e| ops2.contains(e)).collect())
1884                    .unwrap();
1885                intersection.sort();
1886                intersection.dedup();
1887
1888                if !intersection.is_empty() {
1889                    // Factor out the intersection from all the top-level args.
1890
1891                    // Remove the intersection from each inner operand vector.
1892                    inner_operands_refs
1893                        .iter_mut()
1894                        .for_each(|ops| (**ops).retain(|o| !intersection.contains(o)));
1895
1896                    // Simplify terms that now have only 0 or 1 args due to removing the intersection.
1897                    outer_operands
1898                        .iter_mut()
1899                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1900
1901                    // Add the intersection at the beginning
1902                    *self = MirScalarExpr::CallVariadic {
1903                        func: inner_func,
1904                        exprs: intersection.into_iter().chain_one(self.clone()).collect(),
1905                    };
1906                } else {
1907                    // If the intersection was empty, that means that there is nothing we can factor out
1908                    // from _all_ the top-level args. However, we might still find something to factor
1909                    // out from a subset of the top-level args. To find such an opportunity, we look for
1910                    // duplicates across all inner args, e.g. if we have
1911                    // `(...) OR (... AND `a` AND ...) OR (...) OR (... AND `a` AND ...)`
1912                    // then we'll find that `a` occurs in more than one top-level arg, so
1913                    // `indexes_to_undistribute` will point us to the 2. and 4. top-level args.
1914
1915                    // Create (inner_operand, index) pairs, where the index is the position in
1916                    // outer_operands
1917                    let all_inner_operands = inner_operands_refs
1918                        .iter()
1919                        .enumerate()
1920                        .flat_map(|(i, inner_vec)| inner_vec.iter().map(move |a| ((*a).clone(), i)))
1921                        .sorted()
1922                        .collect_vec();
1923
1924                    // Find inner operand expressions that occur in more than one top-level arg.
1925                    // Each inner vector in `undistribution_opportunities` will belong to one such inner
1926                    // operand expression, and it is a set of indexes pointing to top-level args where
1927                    // that inner operand occurs.
1928                    let undistribution_opportunities = all_inner_operands
1929                        .iter()
1930                        .chunk_by(|(a, _i)| a)
1931                        .into_iter()
1932                        .map(|(_a, g)| g.map(|(_a, i)| *i).sorted().dedup().collect_vec())
1933                        .filter(|g| g.len() > 1)
1934                        .collect_vec();
1935
1936                    // Choose one of the inner vectors from `undistribution_opportunities`.
1937                    let indexes_to_undistribute = undistribution_opportunities
1938                        .iter()
1939                        // Let's prefer index sets that directly lead to an absorption.
1940                        .find(|index_set| {
1941                            index_set
1942                                .iter()
1943                                .any(|i| inner_operands_refs.get(*i).unwrap().len() == 1)
1944                        })
1945                        // If we didn't find any absorption, then any index set will do.
1946                        .or_else(|| undistribution_opportunities.first())
1947                        .cloned();
1948
1949                    // In any case, undo the 1-arg wrapping that we did at the beginning.
1950                    outer_operands
1951                        .iter_mut()
1952                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1953
1954                    if let Some(indexes_to_undistribute) = indexes_to_undistribute {
1955                        // Found something to undistribute from a subset of the outer operands.
1956                        // We temporarily remove these from outer_operands, call ourselves on it, and
1957                        // then push back the result.
1958                        let mut undistribute_from = MirScalarExpr::CallVariadic {
1959                            func: outer_func.clone(),
1960                            exprs: swap_remove_multiple(outer_operands, indexes_to_undistribute),
1961                        };
1962                        // By construction, the recursive call is guaranteed to hit
1963                        // the `!intersection.is_empty()` branch.
1964                        undistribute_from.undistribute_and_or();
1965                        // Append the undistributed result to outer operands that were not included in
1966                        // indexes_to_undistribute.
1967                        outer_operands.push(undistribute_from);
1968                    }
1969                }
1970            }
1971        }
1972    }
1973
1974    /* #endregion */
1975
1976    /// Adds any columns that *must* be non-Null for `self` to be non-Null.
1977    pub fn non_null_requirements(&self, columns: &mut BTreeSet<usize>) {
1978        match self {
1979            MirScalarExpr::Column(col, _name) => {
1980                columns.insert(*col);
1981            }
1982            MirScalarExpr::Literal(..) => {}
1983            MirScalarExpr::CallUnmaterializable(_) => (),
1984            MirScalarExpr::CallUnary { func, expr } => {
1985                if func.propagates_nulls() {
1986                    expr.non_null_requirements(columns);
1987                }
1988            }
1989            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1990                if func.propagates_nulls() {
1991                    expr1.non_null_requirements(columns);
1992                    expr2.non_null_requirements(columns);
1993                }
1994            }
1995            MirScalarExpr::CallVariadic { func, exprs } => {
1996                if func.propagates_nulls() {
1997                    for expr in exprs {
1998                        expr.non_null_requirements(columns);
1999                    }
2000                }
2001            }
2002            MirScalarExpr::If { .. } => (),
2003        }
2004    }
2005
2006    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2007        let repr_column_types = column_types.iter().map(ReprColumnType::from).collect_vec();
2008        SqlColumnType::from_repr(&self.typ(&repr_column_types))
2009    }
2010
2011    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2012        match self {
2013            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
2014            MirScalarExpr::Literal(_, typ) => typ.clone(),
2015            MirScalarExpr::CallUnmaterializable(func) => func.output_type(),
2016            MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)),
2017            MirScalarExpr::CallBinary { expr1, expr2, func } => {
2018                func.output_type(&[expr1.typ(column_types), expr2.typ(column_types)])
2019            }
2020            MirScalarExpr::CallVariadic { exprs, func } => {
2021                func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect())
2022            }
2023            MirScalarExpr::If { cond: _, then, els } => {
2024                let then_type = then.typ(column_types);
2025                let else_type = els.typ(column_types);
2026                then_type.union(&else_type).unwrap()
2027            }
2028        }
2029    }
2030
2031    pub fn eval<'a>(
2032        &'a self,
2033        datums: &[Datum<'a>],
2034        temp_storage: &'a RowArena,
2035    ) -> Result<Datum<'a>, EvalError> {
2036        match self {
2037            MirScalarExpr::Column(index, _name) => Ok(datums[*index]),
2038            MirScalarExpr::Literal(res, _column_type) => match res {
2039                Ok(row) => Ok(row.unpack_first()),
2040                Err(e) => Err(e.clone()),
2041            },
2042            // Unmaterializable functions must be transformed away before
2043            // evaluation. Their purpose is as a placeholder for data that is
2044            // not known at plan time but can be inlined before runtime.
2045            MirScalarExpr::CallUnmaterializable(x) => Err(EvalError::Internal(
2046                format!("cannot evaluate unmaterializable function: {:?}", x).into(),
2047            )),
2048            MirScalarExpr::CallUnary { func, expr } => func.eval(datums, temp_storage, expr),
2049            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2050                func.eval(datums, temp_storage, &[expr1, expr2])
2051            }
2052            MirScalarExpr::CallVariadic { func, exprs } => func.eval(datums, temp_storage, exprs),
2053            MirScalarExpr::If { cond, then, els } => match cond.eval(datums, temp_storage)? {
2054                Datum::True => then.eval(datums, temp_storage),
2055                Datum::False | Datum::Null => els.eval(datums, temp_storage),
2056                d => Err(EvalError::Internal(
2057                    format!("if condition evaluated to non-boolean datum: {:?}", d).into(),
2058                )),
2059            },
2060        }
2061    }
2062
2063    /// True iff the expression contains
2064    /// `UnmaterializableFunc::MzNow`.
2065    pub fn contains_temporal(&self) -> bool {
2066        let mut contains = false;
2067        self.visit_pre(|e| {
2068            if let MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2069                contains = true;
2070            }
2071        });
2072        contains
2073    }
2074
2075    /// True iff the expression contains an `UnmaterializableFunc`.
2076    pub fn contains_unmaterializable(&self) -> bool {
2077        let mut contains = false;
2078        self.visit_pre(|e| {
2079            if let MirScalarExpr::CallUnmaterializable(_) = e {
2080                contains = true;
2081            }
2082        });
2083        contains
2084    }
2085
2086    /// True iff the expression contains an `UnmaterializableFunc` that is not in the `exceptions`
2087    /// list.
2088    pub fn contains_unmaterializable_except(&self, exceptions: &[UnmaterializableFunc]) -> bool {
2089        let mut contains = false;
2090        self.visit_pre(|e| match e {
2091            MirScalarExpr::CallUnmaterializable(f) if !exceptions.contains(f) => contains = true,
2092            _ => (),
2093        });
2094        contains
2095    }
2096
2097    /// True iff the expression contains a `Column`.
2098    pub fn contains_column(&self) -> bool {
2099        let mut contains = false;
2100        self.visit_pre(|e| {
2101            if let MirScalarExpr::Column(_col, _name) = e {
2102                contains = true;
2103            }
2104        });
2105        contains
2106    }
2107
2108    /// True iff the expression contains a `Dummy`.
2109    pub fn contains_dummy(&self) -> bool {
2110        let mut contains = false;
2111        self.visit_pre(|e| {
2112            if let MirScalarExpr::Literal(row, _) = e {
2113                if let Ok(row) = row {
2114                    contains |= row.iter().any(|d| d.contains_dummy());
2115                }
2116            }
2117        });
2118        contains
2119    }
2120
2121    /// The size of the expression as a tree.
2122    pub fn size(&self) -> usize {
2123        let mut size = 0;
2124        self.visit_pre(&mut |_: &MirScalarExpr| {
2125            size += 1;
2126        });
2127        size
2128    }
2129}
2130
2131impl MirScalarExpr {
2132    /// True iff evaluation could possibly error on non-error input `Datum`.
2133    pub fn could_error(&self) -> bool {
2134        match self {
2135            MirScalarExpr::Column(_col, _name) => false,
2136            MirScalarExpr::Literal(row, ..) => row.is_err(),
2137            MirScalarExpr::CallUnmaterializable(_) => true,
2138            MirScalarExpr::CallUnary { func, expr } => func.could_error() || expr.could_error(),
2139            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2140                func.could_error() || expr1.could_error() || expr2.could_error()
2141            }
2142            MirScalarExpr::CallVariadic { func, exprs } => {
2143                func.could_error() || exprs.iter().any(|e| e.could_error())
2144            }
2145            MirScalarExpr::If { cond, then, els } => {
2146                cond.could_error() || then.could_error() || els.could_error()
2147            }
2148        }
2149    }
2150}
2151
2152impl VisitChildren<Self> for MirScalarExpr {
2153    fn visit_children<F>(&self, mut f: F)
2154    where
2155        F: FnMut(&Self),
2156    {
2157        use MirScalarExpr::*;
2158        match self {
2159            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2160            CallUnary { expr, .. } => {
2161                f(expr);
2162            }
2163            CallBinary { expr1, expr2, .. } => {
2164                f(expr1);
2165                f(expr2);
2166            }
2167            CallVariadic { exprs, .. } => {
2168                for expr in exprs {
2169                    f(expr);
2170                }
2171            }
2172            If { cond, then, els } => {
2173                f(cond);
2174                f(then);
2175                f(els);
2176            }
2177        }
2178    }
2179
2180    fn visit_mut_children<F>(&mut self, mut f: F)
2181    where
2182        F: FnMut(&mut Self),
2183    {
2184        use MirScalarExpr::*;
2185        match self {
2186            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2187            CallUnary { expr, .. } => {
2188                f(expr);
2189            }
2190            CallBinary { expr1, expr2, .. } => {
2191                f(expr1);
2192                f(expr2);
2193            }
2194            CallVariadic { exprs, .. } => {
2195                for expr in exprs {
2196                    f(expr);
2197                }
2198            }
2199            If { cond, then, els } => {
2200                f(cond);
2201                f(then);
2202                f(els);
2203            }
2204        }
2205    }
2206
2207    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2208    where
2209        F: FnMut(&Self) -> Result<(), E>,
2210        E: From<RecursionLimitError>,
2211    {
2212        use MirScalarExpr::*;
2213        match self {
2214            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2215            CallUnary { expr, .. } => {
2216                f(expr)?;
2217            }
2218            CallBinary { expr1, expr2, .. } => {
2219                f(expr1)?;
2220                f(expr2)?;
2221            }
2222            CallVariadic { exprs, .. } => {
2223                for expr in exprs {
2224                    f(expr)?;
2225                }
2226            }
2227            If { cond, then, els } => {
2228                f(cond)?;
2229                f(then)?;
2230                f(els)?;
2231            }
2232        }
2233        Ok(())
2234    }
2235
2236    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2237    where
2238        F: FnMut(&mut Self) -> Result<(), E>,
2239        E: From<RecursionLimitError>,
2240    {
2241        use MirScalarExpr::*;
2242        match self {
2243            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2244            CallUnary { expr, .. } => {
2245                f(expr)?;
2246            }
2247            CallBinary { expr1, expr2, .. } => {
2248                f(expr1)?;
2249                f(expr2)?;
2250            }
2251            CallVariadic { exprs, .. } => {
2252                for expr in exprs {
2253                    f(expr)?;
2254                }
2255            }
2256            If { cond, then, els } => {
2257                f(cond)?;
2258                f(then)?;
2259                f(els)?;
2260            }
2261        }
2262        Ok(())
2263    }
2264}
2265
2266impl MirScalarExpr {
2267    /// Iterates through references to child expressions.
2268    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2269        let mut first = None;
2270        let mut second = None;
2271        let mut third = None;
2272        let mut variadic = None;
2273
2274        use MirScalarExpr::*;
2275        match self {
2276            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2277            CallUnary { expr, .. } => {
2278                first = Some(&**expr);
2279            }
2280            CallBinary { expr1, expr2, .. } => {
2281                first = Some(&**expr1);
2282                second = Some(&**expr2);
2283            }
2284            CallVariadic { exprs, .. } => {
2285                variadic = Some(exprs);
2286            }
2287            If { cond, then, els } => {
2288                first = Some(&**cond);
2289                second = Some(&**then);
2290                third = Some(&**els);
2291            }
2292        }
2293
2294        first
2295            .into_iter()
2296            .chain(second)
2297            .chain(third)
2298            .chain(variadic.into_iter().flatten())
2299    }
2300
2301    /// Iterates through mutable references to child expressions.
2302    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2303        let mut first = None;
2304        let mut second = None;
2305        let mut third = None;
2306        let mut variadic = None;
2307
2308        use MirScalarExpr::*;
2309        match self {
2310            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2311            CallUnary { expr, .. } => {
2312                first = Some(&mut **expr);
2313            }
2314            CallBinary { expr1, expr2, .. } => {
2315                first = Some(&mut **expr1);
2316                second = Some(&mut **expr2);
2317            }
2318            CallVariadic { exprs, .. } => {
2319                variadic = Some(exprs);
2320            }
2321            If { cond, then, els } => {
2322                first = Some(&mut **cond);
2323                second = Some(&mut **then);
2324                third = Some(&mut **els);
2325            }
2326        }
2327
2328        first
2329            .into_iter()
2330            .chain(second)
2331            .chain(third)
2332            .chain(variadic.into_iter().flatten())
2333    }
2334
2335    /// Visits all subexpressions in DFS preorder.
2336    pub fn visit_pre<F>(&self, mut f: F)
2337    where
2338        F: FnMut(&Self),
2339    {
2340        let mut worklist = vec![self];
2341        while let Some(e) = worklist.pop() {
2342            f(e);
2343            worklist.extend(e.children().rev());
2344        }
2345    }
2346
2347    /// Iterative pre-order visitor.
2348    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2349        let mut worklist = vec![self];
2350        while let Some(expr) = worklist.pop() {
2351            f(expr);
2352            worklist.extend(expr.children_mut().rev());
2353        }
2354    }
2355}
2356
/// Filter characteristics that are used for ordering join inputs.
/// This can be created for a `Vec<MirScalarExpr>`, which represents an AND of predicates.
///
/// The fields are ordered based on heuristic assumptions about their typical selectivity, so that
/// the derived `Ord` gives the right ordering for join inputs. Bigger is better, i.e., will tend to
/// come earlier than other inputs. Because `Ord` is derived, reordering the fields below changes
/// the join-ordering heuristic.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct FilterCharacteristics {
    /// `<expr> = <literal>` appears in the filter.
    /// Excludes cases where NOT appears anywhere above the literal equality.
    literal_equality: bool,
    /// A LIKE match (`IsLikeMatch`) appears in the filter with no NOT above it.
    /// (Assuming a random string of lower-case characters, `LIKE 'a%'` has a selectivity of 1/26.)
    like: bool,
    /// `IS NULL` appears in the filter with no NOT above it. (`IS NULL` under a NOT is treated
    /// as `IS NOT NULL` and does not set this flag.)
    is_null: bool,
    /// Number of Vec elements that involve inequality predicates. (A BETWEEN is represented as two
    /// inequality predicates.)
    /// Like `literal_equality`, excludes cases where NOT appears anywhere above the literal
    /// inequality.
    /// Note that for inequality predicates, some databases assume 1/3 selectivity in the absence of
    /// concrete statistics.
    literal_inequality: usize,
    /// Any filter, except ones involving `IS NOT NULL`, because those are too common.
    /// Can be true by itself, or any other field being true can also make this true.
    /// `NOT LIKE` is only in this category.
    /// `!=` is only in this category.
    /// `NOT (a = b)` is turned into `!=` by `reduce` before us!
    any_filter: bool,
}
2395
2396impl BitOrAssign for FilterCharacteristics {
2397    fn bitor_assign(&mut self, rhs: Self) {
2398        self.literal_equality |= rhs.literal_equality;
2399        self.like |= rhs.like;
2400        self.is_null |= rhs.is_null;
2401        self.literal_inequality += rhs.literal_inequality;
2402        self.any_filter |= rhs.any_filter;
2403    }
2404}
2405
2406impl FilterCharacteristics {
2407    pub fn none() -> FilterCharacteristics {
2408        FilterCharacteristics {
2409            literal_equality: false,
2410            like: false,
2411            is_null: false,
2412            literal_inequality: 0,
2413            any_filter: false,
2414        }
2415    }
2416
2417    pub fn explain(&self) -> String {
2418        let mut e = "".to_owned();
2419        if self.literal_equality {
2420            e.push_str("e");
2421        }
2422        if self.like {
2423            e.push_str("l");
2424        }
2425        if self.is_null {
2426            e.push_str("n");
2427        }
2428        for _ in 0..self.literal_inequality {
2429            e.push_str("i");
2430        }
2431        if self.any_filter {
2432            e.push_str("f");
2433        }
2434        e
2435    }
2436
2437    pub fn filter_characteristics(
2438        filters: &Vec<MirScalarExpr>,
2439    ) -> Result<FilterCharacteristics, RecursionLimitError> {
2440        let mut literal_equality = false;
2441        let mut like = false;
2442        let mut is_null = false;
2443        let mut literal_inequality = 0;
2444        let mut any_filter = false;
2445        filters.iter().try_for_each(|f| {
2446            let mut literal_inequality_in_current_filter = false;
2447            let mut is_not_null_in_current_filter = false;
2448            f.visit_pre_with_context(
2449                false,
2450                &mut |not_in_parent_chain, expr| {
2451                    not_in_parent_chain
2452                        || matches!(
2453                            expr,
2454                            MirScalarExpr::CallUnary {
2455                                func: UnaryFunc::Not(func::Not),
2456                                ..
2457                            }
2458                        )
2459                },
2460                &mut |not_in_parent_chain, expr| {
2461                    if !not_in_parent_chain {
2462                        if expr.any_expr_eq_literal().is_some() {
2463                            literal_equality = true;
2464                        }
2465                        if expr.any_expr_ineq_literal() {
2466                            literal_inequality_in_current_filter = true;
2467                        }
2468                        if matches!(
2469                            expr,
2470                            MirScalarExpr::CallUnary {
2471                                func: UnaryFunc::IsLikeMatch(_),
2472                                ..
2473                            }
2474                        ) {
2475                            like = true;
2476                        }
2477                    };
2478                    if matches!(
2479                        expr,
2480                        MirScalarExpr::CallUnary {
2481                            func: UnaryFunc::IsNull(crate::func::IsNull),
2482                            ..
2483                        }
2484                    ) {
2485                        if *not_in_parent_chain {
2486                            is_not_null_in_current_filter = true;
2487                        } else {
2488                            is_null = true;
2489                        }
2490                    }
2491                },
2492            )?;
2493            if literal_inequality_in_current_filter {
2494                literal_inequality += 1;
2495            }
2496            if !is_not_null_in_current_filter {
2497                // We want to ignore `IS NOT NULL` for `any_filter`.
2498                any_filter = true;
2499            }
2500            Ok(())
2501        })?;
2502        Ok(FilterCharacteristics {
2503            literal_equality,
2504            like,
2505            is_null,
2506            literal_inequality,
2507            any_filter,
2508        })
2509    }
2510
2511    pub fn add_literal_equality(&mut self) {
2512        self.literal_equality = true;
2513    }
2514
2515    pub fn worst_case_scaling_factor(&self) -> f64 {
2516        let mut factor = 1.0;
2517
2518        if self.literal_equality {
2519            factor *= 0.1;
2520        }
2521
2522        if self.is_null {
2523            factor *= 0.1;
2524        }
2525
2526        if self.literal_inequality >= 2 {
2527            factor *= 0.25;
2528        } else if self.literal_inequality == 1 {
2529            factor *= 0.33;
2530        }
2531
2532        // catch various negated filters, treat them pessimistically
2533        if !(self.literal_equality || self.is_null || self.literal_inequality > 0)
2534            && self.any_filter
2535        {
2536            factor *= 0.9;
2537        }
2538
2539        factor
2540    }
2541}
2542
/// One end of a domain bound, as carried by `EvalError::OutOfDomain`.
///
/// NOTE(review): presumably the two `DomainLimit`s in `OutOfDomain` are the lower and upper
/// bounds respectively; confirm against `EvalError`'s `Display` impl.
///
/// Variant order is significant: `Ord`, `PartialOrd`, and `Hash` are derived, so reordering the
/// variants changes comparison and hashing results.
#[derive(
    Arbitrary,
    Ord,
    PartialOrd,
    Copy,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum DomainLimit {
    /// Presumably: no limit on this side.
    None,
    /// Presumably: a bound that includes the given value.
    Inclusive(i64),
    /// Presumably: a bound that excludes the given value.
    Exclusive(i64),
}
2562
2563impl RustType<ProtoDomainLimit> for DomainLimit {
2564    fn into_proto(&self) -> ProtoDomainLimit {
2565        use proto_domain_limit::Kind::*;
2566        let kind = match self {
2567            DomainLimit::None => None(()),
2568            DomainLimit::Inclusive(v) => Inclusive(*v),
2569            DomainLimit::Exclusive(v) => Exclusive(*v),
2570        };
2571        ProtoDomainLimit { kind: Some(kind) }
2572    }
2573
2574    fn from_proto(proto: ProtoDomainLimit) -> Result<Self, TryFromProtoError> {
2575        use proto_domain_limit::Kind::*;
2576        if let Some(kind) = proto.kind {
2577            match kind {
2578                None(()) => Ok(DomainLimit::None),
2579                Inclusive(v) => Ok(DomainLimit::Inclusive(v)),
2580                Exclusive(v) => Ok(DomainLimit::Exclusive(v)),
2581            }
2582        } else {
2583            Err(TryFromProtoError::missing_field("ProtoDomainLimit::kind"))
2584        }
2585    }
2586}
2587
/// An error produced while evaluating a scalar expression.
///
/// `MirScalarExpr::eval` returns these, and a `MirScalarExpr::Literal` may hold one in place of a
/// value. The user-facing message for each variant comes from the `fmt::Display` impl for this
/// type.
///
/// Variant order is significant: `Ord`, `PartialOrd`, and `Hash` are derived, so reordering the
/// variants changes comparison and hashing results. It may also change serialized forms whose
/// encoding relies on variant indices.
#[derive(
    Arbitrary,
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum EvalError {
    CharacterNotValidForEncoding(i32),
    CharacterTooLargeForEncoding(i32),
    DateBinOutOfRange(Box<str>),
    DivisionByZero,
    Unsupported {
        feature: Box<str>,
        discussion_no: Option<usize>,
    },
    FloatOverflow,
    FloatUnderflow,
    NumericFieldOverflow,
    Float32OutOfRange(Box<str>),
    Float64OutOfRange(Box<str>),
    Int16OutOfRange(Box<str>),
    Int32OutOfRange(Box<str>),
    Int64OutOfRange(Box<str>),
    UInt16OutOfRange(Box<str>),
    UInt32OutOfRange(Box<str>),
    UInt64OutOfRange(Box<str>),
    MzTimestampOutOfRange(Box<str>),
    MzTimestampStepOverflow,
    OidOutOfRange(Box<str>),
    IntervalOutOfRange(Box<str>),
    TimestampCannotBeNan,
    TimestampOutOfRange,
    DateOutOfRange,
    CharOutOfRange,
    IndexOutOfRange {
        provided: i32,
        // The last valid index position, i.e. `v.len() - 1`
        valid_end: i32,
    },
    InvalidBase64Equals,
    InvalidBase64Symbol(char),
    InvalidBase64EndSequence,
    InvalidTimezone(Box<str>),
    InvalidTimezoneInterval,
    InvalidTimezoneConversion,
    InvalidIanaTimezoneId(Box<str>),
    InvalidLayer {
        max_layer: usize,
        val: i64,
    },
    InvalidArray(InvalidArrayError),
    InvalidEncodingName(Box<str>),
    InvalidHashAlgorithm(Box<str>),
    InvalidByteSequence {
        byte_sequence: Box<str>,
        encoding_name: Box<str>,
    },
    InvalidJsonbCast {
        from: Box<str>,
        to: Box<str>,
    },
    InvalidRegex(Box<str>),
    InvalidRegexFlag(char),
    InvalidParameterValue(Box<str>),
    InvalidDatePart(Box<str>),
    KeyCannotBeNull,
    NegSqrt,
    NegLimit,
    NullCharacterNotPermitted,
    UnknownUnits(Box<str>),
    UnsupportedUnits(Box<str>, Box<str>),
    UnterminatedLikeEscapeSequence,
    Parse(ParseError),
    ParseHex(ParseHexError),
    /// Used, e.g., by `MirScalarExpr::eval` when asked to evaluate an unmaterializable function
    /// or when an `If` condition evaluates to a non-boolean datum.
    Internal(Box<str>),
    InfinityOutOfDomain(Box<str>),
    NegativeOutOfDomain(Box<str>),
    ZeroOutOfDomain(Box<str>),
    OutOfDomain(DomainLimit, DomainLimit, Box<str>),
    ComplexOutOfRange(Box<str>),
    MultipleRowsFromSubquery,
    Undefined(Box<str>),
    LikePatternTooLong,
    LikeEscapeTooLong,
    StringValueTooLong {
        target_type: Box<str>,
        length: usize,
    },
    MultidimensionalArrayRemovalNotSupported,
    IncompatibleArrayDimensions {
        dims: Option<(usize, usize)>,
    },
    TypeFromOid(Box<str>),
    InvalidRange(InvalidRangeError),
    InvalidRoleId(Box<str>),
    InvalidPrivileges(Box<str>),
    LetRecLimitExceeded(Box<str>),
    MultiDimensionalArraySearch,
    MustNotBeNull(Box<str>),
    InvalidIdentifier {
        ident: Box<str>,
        detail: Option<Box<str>>,
    },
    ArrayFillWrongArraySubscripts,
    // TODO: propagate this check more widely throughout the expr crate
    MaxArraySizeExceeded(usize),
    DateDiffOverflow {
        unit: Box<str>,
        a: Box<str>,
        b: Box<str>,
    },
    // The error for ErrorIfNull; this should not be used in other contexts as a generic error
    // printer.
    IfNullError(Box<str>),
    LengthTooLarge,
    AclArrayNullElement,
    MzAclArrayNullElement,
    PrettyError(Box<str>),
}
2714
2715impl fmt::Display for EvalError {
2716    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2717        match self {
2718            EvalError::CharacterNotValidForEncoding(v) => {
2719                write!(f, "requested character not valid for encoding: {v}")
2720            }
2721            EvalError::CharacterTooLargeForEncoding(v) => {
2722                write!(f, "requested character too large for encoding: {v}")
2723            }
2724            EvalError::DateBinOutOfRange(message) => f.write_str(message),
2725            EvalError::DivisionByZero => f.write_str("division by zero"),
2726            EvalError::Unsupported {
2727                feature,
2728                discussion_no,
2729            } => {
2730                write!(f, "{} not yet supported", feature)?;
2731                if let Some(discussion_no) = discussion_no {
2732                    write!(
2733                        f,
2734                        ", see https://github.com/MaterializeInc/materialize/discussions/{} for more details",
2735                        discussion_no
2736                    )?;
2737                }
2738                Ok(())
2739            }
2740            EvalError::FloatOverflow => f.write_str("value out of range: overflow"),
2741            EvalError::FloatUnderflow => f.write_str("value out of range: underflow"),
2742            EvalError::NumericFieldOverflow => f.write_str("numeric field overflow"),
2743            EvalError::Float32OutOfRange(val) => write!(f, "{} real out of range", val.quoted()),
2744            EvalError::Float64OutOfRange(val) => {
2745                write!(f, "{} double precision out of range", val.quoted())
2746            }
2747            EvalError::Int16OutOfRange(val) => write!(f, "{} smallint out of range", val.quoted()),
2748            EvalError::Int32OutOfRange(val) => write!(f, "{} integer out of range", val.quoted()),
2749            EvalError::Int64OutOfRange(val) => write!(f, "{} bigint out of range", val.quoted()),
2750            EvalError::UInt16OutOfRange(val) => write!(f, "{} uint2 out of range", val.quoted()),
2751            EvalError::UInt32OutOfRange(val) => write!(f, "{} uint4 out of range", val.quoted()),
2752            EvalError::UInt64OutOfRange(val) => write!(f, "{} uint8 out of range", val.quoted()),
2753            EvalError::MzTimestampOutOfRange(val) => {
2754                write!(f, "{} mz_timestamp out of range", val.quoted())
2755            }
2756            EvalError::MzTimestampStepOverflow => f.write_str("step mz_timestamp overflow"),
2757            EvalError::OidOutOfRange(val) => write!(f, "{} OID out of range", val.quoted()),
2758            EvalError::IntervalOutOfRange(val) => {
2759                write!(f, "{} interval out of range", val.quoted())
2760            }
2761            EvalError::TimestampCannotBeNan => f.write_str("timestamp cannot be NaN"),
2762            EvalError::TimestampOutOfRange => f.write_str("timestamp out of range"),
2763            EvalError::DateOutOfRange => f.write_str("date out of range"),
2764            EvalError::CharOutOfRange => f.write_str("\"char\" out of range"),
2765            EvalError::IndexOutOfRange {
2766                provided,
2767                valid_end,
2768            } => write!(f, "index {provided} out of valid range, 0..{valid_end}",),
2769            EvalError::InvalidBase64Equals => {
2770                f.write_str("unexpected \"=\" while decoding base64 sequence")
2771            }
2772            EvalError::InvalidBase64Symbol(c) => write!(
2773                f,
2774                "invalid symbol \"{}\" found while decoding base64 sequence",
2775                c.escape_default()
2776            ),
2777            EvalError::InvalidBase64EndSequence => f.write_str("invalid base64 end sequence"),
2778            EvalError::InvalidJsonbCast { from, to } => {
2779                write!(f, "cannot cast jsonb {} to type {}", from, to)
2780            }
2781            EvalError::InvalidTimezone(tz) => write!(f, "invalid time zone '{}'", tz),
2782            EvalError::InvalidTimezoneInterval => {
2783                f.write_str("timezone interval must not contain months or years")
2784            }
2785            EvalError::InvalidTimezoneConversion => f.write_str("invalid timezone conversion"),
2786            EvalError::InvalidIanaTimezoneId(tz) => {
2787                write!(f, "invalid IANA Time Zone Database identifier: '{}'", tz)
2788            }
2789            EvalError::InvalidLayer { max_layer, val } => write!(
2790                f,
2791                "invalid layer: {}; must use value within [1, {}]",
2792                val, max_layer
2793            ),
2794            EvalError::InvalidArray(e) => e.fmt(f),
2795            EvalError::InvalidEncodingName(name) => write!(f, "invalid encoding name '{}'", name),
2796            EvalError::InvalidHashAlgorithm(alg) => write!(f, "invalid hash algorithm '{}'", alg),
2797            EvalError::InvalidByteSequence {
2798                byte_sequence,
2799                encoding_name,
2800            } => write!(
2801                f,
2802                "invalid byte sequence '{}' for encoding '{}'",
2803                byte_sequence, encoding_name
2804            ),
2805            EvalError::InvalidDatePart(part) => write!(f, "invalid datepart {}", part.quoted()),
2806            EvalError::KeyCannotBeNull => f.write_str("key cannot be null"),
2807            EvalError::NegSqrt => f.write_str("cannot take square root of a negative number"),
2808            EvalError::NegLimit => f.write_str("LIMIT must not be negative"),
2809            EvalError::NullCharacterNotPermitted => f.write_str("null character not permitted"),
2810            EvalError::InvalidRegex(e) => write!(f, "invalid regular expression: {}", e),
2811            EvalError::InvalidRegexFlag(c) => write!(f, "invalid regular expression flag: {}", c),
2812            EvalError::InvalidParameterValue(s) => f.write_str(s),
2813            EvalError::UnknownUnits(units) => write!(f, "unit '{}' not recognized", units),
2814            EvalError::UnsupportedUnits(units, typ) => {
2815                write!(f, "unit '{}' not supported for type {}", units, typ)
2816            }
2817            EvalError::UnterminatedLikeEscapeSequence => {
2818                f.write_str("unterminated escape sequence in LIKE")
2819            }
2820            EvalError::Parse(e) => e.fmt(f),
2821            EvalError::PrettyError(e) => e.fmt(f),
2822            EvalError::ParseHex(e) => e.fmt(f),
2823            EvalError::Internal(s) => write!(f, "internal error: {}", s),
2824            EvalError::InfinityOutOfDomain(s) => {
2825                write!(f, "function {} is only defined for finite arguments", s)
2826            }
2827            EvalError::NegativeOutOfDomain(s) => {
2828                write!(f, "function {} is not defined for negative numbers", s)
2829            }
2830            EvalError::ZeroOutOfDomain(s) => {
2831                write!(f, "function {} is not defined for zero", s)
2832            }
2833            EvalError::OutOfDomain(lower, upper, s) => {
2834                use DomainLimit::*;
2835                write!(f, "function {s} is defined for numbers ")?;
2836                match (lower, upper) {
2837                    (Inclusive(n), None) => write!(f, "greater than or equal to {n}"),
2838                    (Exclusive(n), None) => write!(f, "greater than {n}"),
2839                    (None, Inclusive(n)) => write!(f, "less than or equal to {n}"),
2840                    (None, Exclusive(n)) => write!(f, "less than {n}"),
2841                    (Inclusive(lo), Inclusive(hi)) => write!(f, "between {lo} and {hi} inclusive"),
2842                    (Exclusive(lo), Exclusive(hi)) => write!(f, "between {lo} and {hi} exclusive"),
2843                    (Inclusive(lo), Exclusive(hi)) => {
2844                        write!(f, "between {lo} inclusive and {hi} exclusive")
2845                    }
2846                    (Exclusive(lo), Inclusive(hi)) => {
2847                        write!(f, "between {lo} exclusive and {hi} inclusive")
2848                    }
2849                    (None, None) => panic!("invalid domain error"),
2850                }
2851            }
2852            EvalError::ComplexOutOfRange(s) => {
2853                write!(f, "function {} cannot return complex numbers", s)
2854            }
2855            EvalError::MultipleRowsFromSubquery => {
2856                write!(f, "more than one record produced in subquery")
2857            }
2858            EvalError::Undefined(s) => {
2859                write!(f, "{} is undefined", s)
2860            }
2861            EvalError::LikePatternTooLong => {
2862                write!(f, "LIKE pattern exceeds maximum length")
2863            }
2864            EvalError::LikeEscapeTooLong => {
2865                write!(f, "invalid escape string")
2866            }
2867            EvalError::StringValueTooLong {
2868                target_type,
2869                length,
2870            } => {
2871                write!(f, "value too long for type {}({})", target_type, length)
2872            }
2873            EvalError::MultidimensionalArrayRemovalNotSupported => {
2874                write!(
2875                    f,
2876                    "removing elements from multidimensional arrays is not supported"
2877                )
2878            }
2879            EvalError::IncompatibleArrayDimensions { dims: _ } => {
2880                write!(f, "cannot concatenate incompatible arrays")
2881            }
2882            EvalError::TypeFromOid(msg) => write!(f, "{msg}"),
2883            EvalError::InvalidRange(e) => e.fmt(f),
2884            EvalError::InvalidRoleId(msg) => write!(f, "{msg}"),
2885            EvalError::InvalidPrivileges(privilege) => {
2886                write!(f, "unrecognized privilege type: {privilege}")
2887            }
2888            EvalError::LetRecLimitExceeded(max_iters) => {
2889                write!(
2890                    f,
2891                    "Recursive query exceeded the recursion limit {}. (Use RETURN AT RECURSION LIMIT to not error, but return the current state as the final result when reaching the limit.)",
2892                    max_iters
2893                )
2894            }
2895            EvalError::MultiDimensionalArraySearch => write!(
2896                f,
2897                "searching for elements in multidimensional arrays is not supported"
2898            ),
2899            EvalError::MustNotBeNull(v) => write!(f, "{v} must not be null"),
2900            EvalError::InvalidIdentifier { ident, .. } => {
2901                write!(f, "string is not a valid identifier: {}", ident.quoted())
2902            }
2903            EvalError::ArrayFillWrongArraySubscripts => {
2904                f.write_str("wrong number of array subscripts")
2905            }
2906            EvalError::MaxArraySizeExceeded(max_size) => {
2907                write!(
2908                    f,
2909                    "array size exceeds the maximum allowed ({max_size} bytes)"
2910                )
2911            }
2912            EvalError::DateDiffOverflow { unit, a, b } => {
2913                write!(f, "datediff overflow, {unit} of {a}, {b}")
2914            }
2915            EvalError::IfNullError(s) => f.write_str(s),
2916            EvalError::LengthTooLarge => write!(f, "requested length too large"),
2917            EvalError::AclArrayNullElement => write!(f, "ACL arrays must not contain null values"),
2918            EvalError::MzAclArrayNullElement => {
2919                write!(f, "MZ_ACL arrays must not contain null values")
2920            }
2921        }
2922    }
2923}
2924
2925impl EvalError {
2926    pub fn detail(&self) -> Option<String> {
2927        match self {
2928            EvalError::IncompatibleArrayDimensions { dims: None } => Some(
2929                "Arrays with differing dimensions are not compatible for concatenation.".into(),
2930            ),
2931            EvalError::IncompatibleArrayDimensions {
2932                dims: Some((a_dims, b_dims)),
2933            } => Some(format!(
2934                "Arrays of {} and {} dimensions are not compatible for concatenation.",
2935                a_dims, b_dims
2936            )),
2937            EvalError::InvalidIdentifier { detail, .. } => detail.as_deref().map(Into::into),
2938            EvalError::ArrayFillWrongArraySubscripts => {
2939                Some("Low bound array has different size than dimensions array.".into())
2940            }
2941            _ => None,
2942        }
2943    }
2944
2945    pub fn hint(&self) -> Option<String> {
2946        match self {
2947            EvalError::InvalidBase64EndSequence => Some(
2948                "Input data is missing padding, is truncated, or is otherwise corrupted.".into(),
2949            ),
2950            EvalError::LikeEscapeTooLong => {
2951                Some("Escape string must be empty or one character.".into())
2952            }
2953            EvalError::MzTimestampOutOfRange(_) => Some(
2954                "Integer, numeric, and text casts to mz_timestamp must be in the form of whole \
2955                milliseconds since the Unix epoch. Values with fractional parts cannot be \
2956                converted to mz_timestamp."
2957                    .into(),
2958            ),
2959            _ => None,
2960        }
2961    }
2962}
2963
// Marker impl: `EvalError` can be used wherever a `std::error::Error` is expected.
// All trait methods (e.g. `source`) use their default implementations.
impl std::error::Error for EvalError {}
2965
2966impl From<ParseError> for EvalError {
2967    fn from(e: ParseError) -> EvalError {
2968        EvalError::Parse(e)
2969    }
2970}
2971
2972impl From<ParseHexError> for EvalError {
2973    fn from(e: ParseHexError) -> EvalError {
2974        EvalError::ParseHex(e)
2975    }
2976}
2977
2978impl From<InvalidArrayError> for EvalError {
2979    fn from(e: InvalidArrayError) -> EvalError {
2980        EvalError::InvalidArray(e)
2981    }
2982}
2983
2984impl From<RegexCompilationError> for EvalError {
2985    fn from(e: RegexCompilationError) -> EvalError {
2986        EvalError::InvalidRegex(e.to_string().into())
2987    }
2988}
2989
2990impl From<TypeFromOidError> for EvalError {
2991    fn from(e: TypeFromOidError) -> EvalError {
2992        EvalError::TypeFromOid(e.to_string().into())
2993    }
2994}
2995
2996impl From<DateError> for EvalError {
2997    fn from(e: DateError) -> EvalError {
2998        match e {
2999            DateError::OutOfRange => EvalError::DateOutOfRange,
3000        }
3001    }
3002}
3003
3004impl From<TimestampError> for EvalError {
3005    fn from(e: TimestampError) -> EvalError {
3006        match e {
3007            TimestampError::OutOfRange => EvalError::TimestampOutOfRange,
3008        }
3009    }
3010}
3011
3012impl From<InvalidRangeError> for EvalError {
3013    fn from(e: InvalidRangeError) -> EvalError {
3014        EvalError::InvalidRange(e)
3015    }
3016}
3017
3018impl RustType<ProtoEvalError> for EvalError {
3019    fn into_proto(&self) -> ProtoEvalError {
3020        use proto_eval_error::Kind::*;
3021        use proto_eval_error::*;
3022        let kind = match self {
3023            EvalError::CharacterNotValidForEncoding(v) => CharacterNotValidForEncoding(*v),
3024            EvalError::CharacterTooLargeForEncoding(v) => CharacterTooLargeForEncoding(*v),
3025            EvalError::DateBinOutOfRange(v) => DateBinOutOfRange(v.into_proto()),
3026            EvalError::DivisionByZero => DivisionByZero(()),
3027            EvalError::Unsupported {
3028                feature,
3029                discussion_no,
3030            } => Unsupported(ProtoUnsupported {
3031                feature: feature.into_proto(),
3032                discussion_no: discussion_no.into_proto(),
3033            }),
3034            EvalError::FloatOverflow => FloatOverflow(()),
3035            EvalError::FloatUnderflow => FloatUnderflow(()),
3036            EvalError::NumericFieldOverflow => NumericFieldOverflow(()),
3037            EvalError::Float32OutOfRange(val) => Float32OutOfRange(ProtoValueOutOfRange {
3038                value: val.to_string(),
3039            }),
3040            EvalError::Float64OutOfRange(val) => Float64OutOfRange(ProtoValueOutOfRange {
3041                value: val.to_string(),
3042            }),
3043            EvalError::Int16OutOfRange(val) => Int16OutOfRange(ProtoValueOutOfRange {
3044                value: val.to_string(),
3045            }),
3046            EvalError::Int32OutOfRange(val) => Int32OutOfRange(ProtoValueOutOfRange {
3047                value: val.to_string(),
3048            }),
3049            EvalError::Int64OutOfRange(val) => Int64OutOfRange(ProtoValueOutOfRange {
3050                value: val.to_string(),
3051            }),
3052            EvalError::UInt16OutOfRange(val) => Uint16OutOfRange(ProtoValueOutOfRange {
3053                value: val.to_string(),
3054            }),
3055            EvalError::UInt32OutOfRange(val) => Uint32OutOfRange(ProtoValueOutOfRange {
3056                value: val.to_string(),
3057            }),
3058            EvalError::UInt64OutOfRange(val) => Uint64OutOfRange(ProtoValueOutOfRange {
3059                value: val.to_string(),
3060            }),
3061            EvalError::MzTimestampOutOfRange(val) => MzTimestampOutOfRange(ProtoValueOutOfRange {
3062                value: val.to_string(),
3063            }),
3064            EvalError::MzTimestampStepOverflow => MzTimestampStepOverflow(()),
3065            EvalError::OidOutOfRange(val) => OidOutOfRange(ProtoValueOutOfRange {
3066                value: val.to_string(),
3067            }),
3068            EvalError::IntervalOutOfRange(val) => IntervalOutOfRange(ProtoValueOutOfRange {
3069                value: val.to_string(),
3070            }),
3071            EvalError::TimestampCannotBeNan => TimestampCannotBeNan(()),
3072            EvalError::TimestampOutOfRange => TimestampOutOfRange(()),
3073            EvalError::DateOutOfRange => DateOutOfRange(()),
3074            EvalError::CharOutOfRange => CharOutOfRange(()),
3075            EvalError::IndexOutOfRange {
3076                provided,
3077                valid_end,
3078            } => IndexOutOfRange(ProtoIndexOutOfRange {
3079                provided: *provided,
3080                valid_end: *valid_end,
3081            }),
3082            EvalError::InvalidBase64Equals => InvalidBase64Equals(()),
3083            EvalError::InvalidBase64Symbol(sym) => InvalidBase64Symbol(sym.into_proto()),
3084            EvalError::InvalidBase64EndSequence => InvalidBase64EndSequence(()),
3085            EvalError::InvalidTimezone(tz) => InvalidTimezone(tz.into_proto()),
3086            EvalError::InvalidTimezoneInterval => InvalidTimezoneInterval(()),
3087            EvalError::InvalidTimezoneConversion => InvalidTimezoneConversion(()),
3088            EvalError::InvalidLayer { max_layer, val } => InvalidLayer(ProtoInvalidLayer {
3089                max_layer: max_layer.into_proto(),
3090                val: *val,
3091            }),
3092            EvalError::InvalidArray(error) => InvalidArray(error.into_proto()),
3093            EvalError::InvalidEncodingName(v) => InvalidEncodingName(v.into_proto()),
3094            EvalError::InvalidHashAlgorithm(v) => InvalidHashAlgorithm(v.into_proto()),
3095            EvalError::InvalidByteSequence {
3096                byte_sequence,
3097                encoding_name,
3098            } => InvalidByteSequence(ProtoInvalidByteSequence {
3099                byte_sequence: byte_sequence.into_proto(),
3100                encoding_name: encoding_name.into_proto(),
3101            }),
3102            EvalError::InvalidJsonbCast { from, to } => InvalidJsonbCast(ProtoInvalidJsonbCast {
3103                from: from.into_proto(),
3104                to: to.into_proto(),
3105            }),
3106            EvalError::InvalidRegex(v) => InvalidRegex(v.into_proto()),
3107            EvalError::InvalidRegexFlag(v) => InvalidRegexFlag(v.into_proto()),
3108            EvalError::InvalidParameterValue(v) => InvalidParameterValue(v.into_proto()),
3109            EvalError::InvalidDatePart(part) => InvalidDatePart(part.into_proto()),
3110            EvalError::KeyCannotBeNull => KeyCannotBeNull(()),
3111            EvalError::NegSqrt => NegSqrt(()),
3112            EvalError::NegLimit => NegLimit(()),
3113            EvalError::NullCharacterNotPermitted => NullCharacterNotPermitted(()),
3114            EvalError::UnknownUnits(v) => UnknownUnits(v.into_proto()),
3115            EvalError::UnsupportedUnits(units, typ) => UnsupportedUnits(ProtoUnsupportedUnits {
3116                units: units.into_proto(),
3117                typ: typ.into_proto(),
3118            }),
3119            EvalError::UnterminatedLikeEscapeSequence => UnterminatedLikeEscapeSequence(()),
3120            EvalError::Parse(error) => Parse(error.into_proto()),
3121            EvalError::PrettyError(error) => PrettyError(error.into_proto()),
3122            EvalError::ParseHex(error) => ParseHex(error.into_proto()),
3123            EvalError::Internal(v) => Internal(v.into_proto()),
3124            EvalError::InfinityOutOfDomain(v) => InfinityOutOfDomain(v.into_proto()),
3125            EvalError::NegativeOutOfDomain(v) => NegativeOutOfDomain(v.into_proto()),
3126            EvalError::ZeroOutOfDomain(v) => ZeroOutOfDomain(v.into_proto()),
3127            EvalError::OutOfDomain(lower, upper, id) => OutOfDomain(ProtoOutOfDomain {
3128                lower: Some(lower.into_proto()),
3129                upper: Some(upper.into_proto()),
3130                id: id.into_proto(),
3131            }),
3132            EvalError::ComplexOutOfRange(v) => ComplexOutOfRange(v.into_proto()),
3133            EvalError::MultipleRowsFromSubquery => MultipleRowsFromSubquery(()),
3134            EvalError::Undefined(v) => Undefined(v.into_proto()),
3135            EvalError::LikePatternTooLong => LikePatternTooLong(()),
3136            EvalError::LikeEscapeTooLong => LikeEscapeTooLong(()),
3137            EvalError::StringValueTooLong {
3138                target_type,
3139                length,
3140            } => StringValueTooLong(ProtoStringValueTooLong {
3141                target_type: target_type.into_proto(),
3142                length: length.into_proto(),
3143            }),
3144            EvalError::MultidimensionalArrayRemovalNotSupported => {
3145                MultidimensionalArrayRemovalNotSupported(())
3146            }
3147            EvalError::IncompatibleArrayDimensions { dims } => {
3148                IncompatibleArrayDimensions(ProtoIncompatibleArrayDimensions {
3149                    dims: dims.into_proto(),
3150                })
3151            }
3152            EvalError::TypeFromOid(v) => TypeFromOid(v.into_proto()),
3153            EvalError::InvalidRange(error) => InvalidRange(error.into_proto()),
3154            EvalError::InvalidRoleId(v) => InvalidRoleId(v.into_proto()),
3155            EvalError::InvalidPrivileges(v) => InvalidPrivileges(v.into_proto()),
3156            EvalError::LetRecLimitExceeded(v) => WmrRecursionLimitExceeded(v.into_proto()),
3157            EvalError::MultiDimensionalArraySearch => MultiDimensionalArraySearch(()),
3158            EvalError::MustNotBeNull(v) => MustNotBeNull(v.into_proto()),
3159            EvalError::InvalidIdentifier { ident, detail } => {
3160                InvalidIdentifier(ProtoInvalidIdentifier {
3161                    ident: ident.into_proto(),
3162                    detail: detail.into_proto(),
3163                })
3164            }
3165            EvalError::ArrayFillWrongArraySubscripts => ArrayFillWrongArraySubscripts(()),
3166            EvalError::MaxArraySizeExceeded(max_size) => {
3167                MaxArraySizeExceeded(u64::cast_from(*max_size))
3168            }
3169            EvalError::DateDiffOverflow { unit, a, b } => DateDiffOverflow(ProtoDateDiffOverflow {
3170                unit: unit.into_proto(),
3171                a: a.into_proto(),
3172                b: b.into_proto(),
3173            }),
3174            EvalError::IfNullError(s) => IfNullError(s.into_proto()),
3175            EvalError::LengthTooLarge => LengthTooLarge(()),
3176            EvalError::AclArrayNullElement => AclArrayNullElement(()),
3177            EvalError::MzAclArrayNullElement => MzAclArrayNullElement(()),
3178            EvalError::InvalidIanaTimezoneId(s) => InvalidIanaTimezoneId(s.into_proto()),
3179        };
3180        ProtoEvalError { kind: Some(kind) }
3181    }
3182
3183    fn from_proto(proto: ProtoEvalError) -> Result<Self, TryFromProtoError> {
3184        use proto_eval_error::Kind::*;
3185        match proto.kind {
3186            Some(kind) => match kind {
3187                CharacterNotValidForEncoding(v) => Ok(EvalError::CharacterNotValidForEncoding(v)),
3188                CharacterTooLargeForEncoding(v) => Ok(EvalError::CharacterTooLargeForEncoding(v)),
3189                DateBinOutOfRange(v) => Ok(EvalError::DateBinOutOfRange(v.into())),
3190                DivisionByZero(()) => Ok(EvalError::DivisionByZero),
3191                Unsupported(v) => Ok(EvalError::Unsupported {
3192                    feature: v.feature.into(),
3193                    discussion_no: v.discussion_no.into_rust()?,
3194                }),
3195                FloatOverflow(()) => Ok(EvalError::FloatOverflow),
3196                FloatUnderflow(()) => Ok(EvalError::FloatUnderflow),
3197                NumericFieldOverflow(()) => Ok(EvalError::NumericFieldOverflow),
3198                Float32OutOfRange(val) => Ok(EvalError::Float32OutOfRange(val.value.into())),
3199                Float64OutOfRange(val) => Ok(EvalError::Float64OutOfRange(val.value.into())),
3200                Int16OutOfRange(val) => Ok(EvalError::Int16OutOfRange(val.value.into())),
3201                Int32OutOfRange(val) => Ok(EvalError::Int32OutOfRange(val.value.into())),
3202                Int64OutOfRange(val) => Ok(EvalError::Int64OutOfRange(val.value.into())),
3203                Uint16OutOfRange(val) => Ok(EvalError::UInt16OutOfRange(val.value.into())),
3204                Uint32OutOfRange(val) => Ok(EvalError::UInt32OutOfRange(val.value.into())),
3205                Uint64OutOfRange(val) => Ok(EvalError::UInt64OutOfRange(val.value.into())),
3206                MzTimestampOutOfRange(val) => {
3207                    Ok(EvalError::MzTimestampOutOfRange(val.value.into()))
3208                }
3209                MzTimestampStepOverflow(()) => Ok(EvalError::MzTimestampStepOverflow),
3210                OidOutOfRange(val) => Ok(EvalError::OidOutOfRange(val.value.into())),
3211                IntervalOutOfRange(val) => Ok(EvalError::IntervalOutOfRange(val.value.into())),
3212                TimestampCannotBeNan(()) => Ok(EvalError::TimestampCannotBeNan),
3213                TimestampOutOfRange(()) => Ok(EvalError::TimestampOutOfRange),
3214                DateOutOfRange(()) => Ok(EvalError::DateOutOfRange),
3215                CharOutOfRange(()) => Ok(EvalError::CharOutOfRange),
3216                IndexOutOfRange(v) => Ok(EvalError::IndexOutOfRange {
3217                    provided: v.provided,
3218                    valid_end: v.valid_end,
3219                }),
3220                InvalidBase64Equals(()) => Ok(EvalError::InvalidBase64Equals),
3221                InvalidBase64Symbol(v) => char::from_proto(v).map(EvalError::InvalidBase64Symbol),
3222                InvalidBase64EndSequence(()) => Ok(EvalError::InvalidBase64EndSequence),
3223                InvalidTimezone(v) => Ok(EvalError::InvalidTimezone(v.into())),
3224                InvalidTimezoneInterval(()) => Ok(EvalError::InvalidTimezoneInterval),
3225                InvalidTimezoneConversion(()) => Ok(EvalError::InvalidTimezoneConversion),
3226                InvalidLayer(v) => Ok(EvalError::InvalidLayer {
3227                    max_layer: usize::from_proto(v.max_layer)?,
3228                    val: v.val,
3229                }),
3230                InvalidArray(error) => Ok(EvalError::InvalidArray(error.into_rust()?)),
3231                InvalidEncodingName(v) => Ok(EvalError::InvalidEncodingName(v.into())),
3232                InvalidHashAlgorithm(v) => Ok(EvalError::InvalidHashAlgorithm(v.into())),
3233                InvalidByteSequence(v) => Ok(EvalError::InvalidByteSequence {
3234                    byte_sequence: v.byte_sequence.into(),
3235                    encoding_name: v.encoding_name.into(),
3236                }),
3237                InvalidJsonbCast(v) => Ok(EvalError::InvalidJsonbCast {
3238                    from: v.from.into(),
3239                    to: v.to.into(),
3240                }),
3241                InvalidRegex(v) => Ok(EvalError::InvalidRegex(v.into())),
3242                InvalidRegexFlag(v) => Ok(EvalError::InvalidRegexFlag(char::from_proto(v)?)),
3243                InvalidParameterValue(v) => Ok(EvalError::InvalidParameterValue(v.into())),
3244                InvalidDatePart(part) => Ok(EvalError::InvalidDatePart(part.into())),
3245                KeyCannotBeNull(()) => Ok(EvalError::KeyCannotBeNull),
3246                NegSqrt(()) => Ok(EvalError::NegSqrt),
3247                NegLimit(()) => Ok(EvalError::NegLimit),
3248                NullCharacterNotPermitted(()) => Ok(EvalError::NullCharacterNotPermitted),
3249                UnknownUnits(v) => Ok(EvalError::UnknownUnits(v.into())),
3250                UnsupportedUnits(v) => {
3251                    Ok(EvalError::UnsupportedUnits(v.units.into(), v.typ.into()))
3252                }
3253                UnterminatedLikeEscapeSequence(()) => Ok(EvalError::UnterminatedLikeEscapeSequence),
3254                Parse(error) => Ok(EvalError::Parse(error.into_rust()?)),
3255                ParseHex(error) => Ok(EvalError::ParseHex(error.into_rust()?)),
3256                Internal(v) => Ok(EvalError::Internal(v.into())),
3257                InfinityOutOfDomain(v) => Ok(EvalError::InfinityOutOfDomain(v.into())),
3258                NegativeOutOfDomain(v) => Ok(EvalError::NegativeOutOfDomain(v.into())),
3259                ZeroOutOfDomain(v) => Ok(EvalError::ZeroOutOfDomain(v.into())),
3260                OutOfDomain(v) => Ok(EvalError::OutOfDomain(
3261                    v.lower.into_rust_if_some("ProtoDomainLimit::lower")?,
3262                    v.upper.into_rust_if_some("ProtoDomainLimit::upper")?,
3263                    v.id.into(),
3264                )),
3265                ComplexOutOfRange(v) => Ok(EvalError::ComplexOutOfRange(v.into())),
3266                MultipleRowsFromSubquery(()) => Ok(EvalError::MultipleRowsFromSubquery),
3267                Undefined(v) => Ok(EvalError::Undefined(v.into())),
3268                LikePatternTooLong(()) => Ok(EvalError::LikePatternTooLong),
3269                LikeEscapeTooLong(()) => Ok(EvalError::LikeEscapeTooLong),
3270                StringValueTooLong(v) => Ok(EvalError::StringValueTooLong {
3271                    target_type: v.target_type.into(),
3272                    length: usize::from_proto(v.length)?,
3273                }),
3274                MultidimensionalArrayRemovalNotSupported(()) => {
3275                    Ok(EvalError::MultidimensionalArrayRemovalNotSupported)
3276                }
3277                IncompatibleArrayDimensions(v) => Ok(EvalError::IncompatibleArrayDimensions {
3278                    dims: v.dims.into_rust()?,
3279                }),
3280                TypeFromOid(v) => Ok(EvalError::TypeFromOid(v.into())),
3281                InvalidRange(e) => Ok(EvalError::InvalidRange(e.into_rust()?)),
3282                InvalidRoleId(v) => Ok(EvalError::InvalidRoleId(v.into())),
3283                InvalidPrivileges(v) => Ok(EvalError::InvalidPrivileges(v.into())),
3284                WmrRecursionLimitExceeded(v) => Ok(EvalError::LetRecLimitExceeded(v.into())),
3285                MultiDimensionalArraySearch(()) => Ok(EvalError::MultiDimensionalArraySearch),
3286                MustNotBeNull(v) => Ok(EvalError::MustNotBeNull(v.into())),
3287                InvalidIdentifier(v) => Ok(EvalError::InvalidIdentifier {
3288                    ident: v.ident.into(),
3289                    detail: v.detail.into_rust()?,
3290                }),
3291                ArrayFillWrongArraySubscripts(()) => Ok(EvalError::ArrayFillWrongArraySubscripts),
3292                MaxArraySizeExceeded(max_size) => {
3293                    Ok(EvalError::MaxArraySizeExceeded(usize::cast_from(max_size)))
3294                }
3295                DateDiffOverflow(v) => Ok(EvalError::DateDiffOverflow {
3296                    unit: v.unit.into(),
3297                    a: v.a.into(),
3298                    b: v.b.into(),
3299                }),
3300                IfNullError(v) => Ok(EvalError::IfNullError(v.into())),
3301                LengthTooLarge(()) => Ok(EvalError::LengthTooLarge),
3302                AclArrayNullElement(()) => Ok(EvalError::AclArrayNullElement),
3303                MzAclArrayNullElement(()) => Ok(EvalError::MzAclArrayNullElement),
3304                InvalidIanaTimezoneId(s) => Ok(EvalError::InvalidIanaTimezoneId(s.into())),
3305                PrettyError(s) => Ok(EvalError::PrettyError(s.into())),
3306            },
3307            None => Err(TryFromProtoError::missing_field("ProtoEvalError::kind")),
3308        }
3309    }
3310}
3311
3312impl RustType<ProtoDims> for (usize, usize) {
3313    fn into_proto(&self) -> ProtoDims {
3314        ProtoDims {
3315            f0: self.0.into_proto(),
3316            f1: self.1.into_proto(),
3317        }
3318    }
3319
3320    fn from_proto(proto: ProtoDims) -> Result<Self, TryFromProtoError> {
3321        Ok((proto.f0.into_rust()?, proto.f1.into_rust()?))
3322    }
3323}
3324
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `MirScalarExpr::reduce` simplifies `Coalesce` calls as expected.
    ///
    /// The relation has three `Int64` columns. Columns 0 and 1 are nullable;
    /// column 2 is not. Each case pairs an input expression with the expression
    /// it should reduce to.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_reduce() {
        let relation_type: Vec<ReprColumnType> = [true, true, false]
            .iter()
            .map(|&nullable| ReprScalarType::Int64.nullable(nullable))
            .collect();

        // Small constructors for the expressions used in the cases below.
        let int64_typ = ReprScalarType::Int64;
        let col = MirScalarExpr::column;
        let lit = |i: i64| MirScalarExpr::literal_ok(Datum::Int64(i), int64_typ.clone());
        let null = || MirScalarExpr::literal_null(int64_typ.clone());
        let err = |e: EvalError| MirScalarExpr::literal(Err(e), int64_typ.clone());

        // Each entry is (input, expected expression after `reduce`).
        let cases = [
            // A single literal argument reduces to that literal.
            (MirScalarExpr::call_variadic(Coalesce, vec![lit(1)]), lit(1)),
            // The first non-null literal is the result; later arguments are dropped.
            (
                MirScalarExpr::call_variadic(Coalesce, vec![lit(1), lit(2)]),
                lit(1),
            ),
            // Null literals are skipped over.
            (
                MirScalarExpr::call_variadic(Coalesce, vec![null(), lit(2), null()]),
                lit(2),
            ),
            // Null literals are removed, and arguments after the first
            // non-null literal are dropped.
            (
                MirScalarExpr::call_variadic(
                    Coalesce,
                    vec![null(), col(0), null(), col(1), lit(2), lit(3)],
                ),
                MirScalarExpr::call_variadic(Coalesce, vec![col(0), col(1), lit(2)]),
            ),
            // Column 2 is non-nullable, so arguments after it are dropped.
            (
                MirScalarExpr::call_variadic(Coalesce, vec![col(0), col(2), col(1)]),
                MirScalarExpr::call_variadic(Coalesce, vec![col(0), col(2)]),
            ),
            // An error placed after a non-null literal is dropped.
            (
                MirScalarExpr::call_variadic(
                    Coalesce,
                    vec![lit(1), err(EvalError::DivisionByZero)],
                ),
                lit(1),
            ),
            // Once the leading null is skipped, the first error is the result.
            (
                MirScalarExpr::call_variadic(
                    Coalesce,
                    vec![
                        null(),
                        err(EvalError::DivisionByZero),
                        err(EvalError::NumericFieldOverflow),
                    ],
                ),
                err(EvalError::DivisionByZero),
            ),
        ];

        for (input, expected) in cases {
            let mut reduced = input.clone();
            reduced.reduce(&relation_type);
            assert!(
                reduced == expected,
                "input: {}\nactual: {}\nexpected: {}",
                input,
                reduced,
                expected
            );
        }
    }
}