Skip to main content

mz_expr/
scalar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::ops::BitOrAssign;
12use std::sync::Arc;
13use std::{fmt, mem};
14
15use itertools::Itertools;
16use mz_lowertest::MzReflect;
17use mz_ore::cast::CastFrom;
18use mz_ore::collections::CollectionExt;
19use mz_ore::iter::IteratorExt;
20use mz_ore::soft_assert_or_log;
21use mz_ore::stack::RecursionLimitError;
22use mz_ore::str::StrExt;
23use mz_ore::treat_as_equal::TreatAsEqual;
24use mz_ore::vec::swap_remove_multiple;
25use mz_pgrepr::TypeFromOidError;
26use mz_pgtz::timezone::TimezoneSpec;
27use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
28use mz_repr::adt::array::InvalidArrayError;
29use mz_repr::adt::date::DateError;
30use mz_repr::adt::datetime::DateTimeUnits;
31use mz_repr::adt::range::InvalidRangeError;
32use mz_repr::adt::regex::{Regex, RegexCompilationError};
33use mz_repr::adt::timestamp::TimestampError;
34use mz_repr::strconv::{ParseError, ParseHexError};
35use mz_repr::{Datum, ReprColumnType, ReprScalarType, Row, RowArena, SqlColumnType, SqlScalarType};
36
37use proptest::prelude::*;
38use proptest_derive::Arbitrary;
39use serde::{Deserialize, Serialize};
40
41use crate::explain::{HumanizedExplain, HumanizerMode};
42use crate::scalar::func::format::DateTimeFormat;
43use crate::scalar::func::variadic::{
44    And, Coalesce, ListCreate, ListIndex, Or, RegexpMatch, RegexpReplace, RegexpSplitToArray,
45};
46use crate::scalar::func::{
47    BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, parse_timezone,
48    regexp_replace_parse_flags,
49};
50use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims;
51use crate::visit::{Visit, VisitChildren};
52
53pub mod func;
54pub mod like_pattern;
55
56include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs"));
57
/// A scalar expression tree.
///
/// Leaves are column references, literals, and calls to unmaterializable
/// functions; interior nodes are unary, binary, and variadic function calls
/// as well as conditionals.
#[derive(
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    MzReflect
)]
pub enum MirScalarExpr {
    /// A column of the input row.
    ///
    /// The first field is the column index. The second field is an optional
    /// column name (presumably ignored by comparisons, given the
    /// `TreatAsEqual` wrapper; confirm against `mz_ore::treat_as_equal`).
    Column(usize, TreatAsEqual<Option<Arc<str>>>),
    /// A literal value, or a literal evaluation error, together with its
    /// column type.
    /// (Stored as a row, because we can't own a Datum)
    Literal(Result<Row, EvalError>, ReprColumnType),
    /// A call to an unmaterializable function.
    ///
    /// These functions cannot be evaluated by `MirScalarExpr::eval`. They must
    /// be transformed away by a higher layer.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function call that takes one expression as an argument.
    CallUnary {
        /// The unary function to apply.
        func: UnaryFunc,
        /// The single argument.
        expr: Box<MirScalarExpr>,
    },
    /// A function call that takes two expressions as arguments.
    CallBinary {
        /// The binary function to apply.
        func: BinaryFunc,
        /// The first argument.
        expr1: Box<MirScalarExpr>,
        /// The second argument.
        expr2: Box<MirScalarExpr>,
    },
    /// A function call that takes an arbitrary number of arguments.
    CallVariadic {
        /// The variadic function to apply.
        func: VariadicFunc,
        /// The arguments, in order.
        exprs: Vec<MirScalarExpr>,
    },
    /// Conditionally evaluated expressions.
    ///
    /// It is important that `then` and `els` only be evaluated if
    /// `cond` is true or not, respectively. This is the only way
    /// users can guard execution (other logical operators do not
    /// short-circuit) and we need to preserve that.
    If {
        /// The condition that selects a branch.
        cond: Box<MirScalarExpr>,
        /// The branch used when `cond` is true.
        then: Box<MirScalarExpr>,
        /// The branch used when `cond` is not true.
        els: Box<MirScalarExpr>,
    },
}
108
109// We need a custom Debug because we don't want to show `None` for name information.
110// Sadly, the `derivative` crate doesn't support this use case.
111impl std::fmt::Debug for MirScalarExpr {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            MirScalarExpr::Column(i, TreatAsEqual(Some(name))) => {
115                write!(f, "Column({i}, {name:?})")
116            }
117            MirScalarExpr::Column(i, TreatAsEqual(None)) => write!(f, "Column({i})"),
118            MirScalarExpr::Literal(lit, typ) => write!(f, "Literal({lit:?}, {typ:?})"),
119            MirScalarExpr::CallUnmaterializable(func) => {
120                write!(f, "CallUnmaterializable({func:?})")
121            }
122            MirScalarExpr::CallUnary { func, expr } => {
123                write!(f, "CallUnary({func:?}, {expr:?})")
124            }
125            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
126                write!(f, "CallBinary({func:?}, {expr1:?}, {expr2:?})")
127            }
128            MirScalarExpr::CallVariadic { func, exprs } => {
129                write!(f, "CallVariadic({func:?}, {exprs:?})")
130            }
131            MirScalarExpr::If { cond, then, els } => {
132                write!(f, "If({cond:?}, {then:?}, {els:?})")
133            }
134        }
135    }
136}
137
138impl MirScalarExpr {
139    pub fn columns(is: &[usize]) -> Vec<MirScalarExpr> {
140        is.iter().map(|i| MirScalarExpr::column(*i)).collect()
141    }
142
143    pub fn column(column: usize) -> Self {
144        MirScalarExpr::Column(column, TreatAsEqual(None))
145    }
146
147    pub fn named_column(column: usize, name: Arc<str>) -> Self {
148        MirScalarExpr::Column(column, TreatAsEqual(Some(name)))
149    }
150
151    pub fn literal(res: Result<Datum, EvalError>, typ: ReprScalarType) -> Self {
152        let typ = ReprColumnType {
153            scalar_type: typ,
154            nullable: matches!(res, Ok(Datum::Null)),
155        };
156        let row = res.map(|datum| Row::pack_slice(&[datum]));
157        MirScalarExpr::Literal(row, typ)
158    }
159
160    pub fn literal_ok(datum: Datum, typ: ReprScalarType) -> Self {
161        MirScalarExpr::literal(Ok(datum), typ)
162    }
163
164    /// Constructs a `MirScalarExpr::Literal` from a pre-packed `Row`
165    /// containing a single datum and a `ReprScalarType`. Nullability is
166    /// derived by inspecting the first datum in the row.
167    pub fn literal_from_single_element_row(row: Row, typ: ReprScalarType) -> Self {
168        soft_assert_or_log!(
169            row.iter().count() == 1,
170            "literal_from_row called with a Row containing {} datums",
171            row.iter().count()
172        );
173        let nullable = row.unpack_first() == Datum::Null;
174        let typ = ReprColumnType {
175            scalar_type: typ,
176            nullable,
177        };
178        MirScalarExpr::Literal(Ok(row), typ)
179    }
180
181    pub fn literal_null(typ: ReprScalarType) -> Self {
182        MirScalarExpr::literal_ok(Datum::Null, typ)
183    }
184
185    pub fn literal_false() -> Self {
186        MirScalarExpr::literal_ok(Datum::False, ReprScalarType::Bool)
187    }
188
189    pub fn literal_true() -> Self {
190        MirScalarExpr::literal_ok(Datum::True, ReprScalarType::Bool)
191    }
192
193    pub fn call_unary<U: Into<UnaryFunc>>(self, func: U) -> Self {
194        MirScalarExpr::CallUnary {
195            func: func.into(),
196            expr: Box::new(self),
197        }
198    }
199
200    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
201        MirScalarExpr::CallBinary {
202            func: func.into(),
203            expr1: Box::new(self),
204            expr2: Box::new(other),
205        }
206    }
207
208    /// Call function `func` on `exprs`.
209    pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
210        MirScalarExpr::CallVariadic {
211            func: func.into(),
212            exprs,
213        }
214    }
215
216    pub fn if_then_else(self, t: Self, f: Self) -> Self {
217        MirScalarExpr::If {
218            cond: Box::new(self),
219            then: Box::new(t),
220            els: Box::new(f),
221        }
222    }
223
224    pub fn or(self, other: Self) -> Self {
225        MirScalarExpr::call_variadic(Or, vec![self, other])
226    }
227
228    pub fn and(self, other: Self) -> Self {
229        MirScalarExpr::call_variadic(And, vec![self, other])
230    }
231
232    pub fn not(self) -> Self {
233        self.call_unary(UnaryFunc::Not(func::Not))
234    }
235
236    pub fn call_is_null(self) -> Self {
237        self.call_unary(UnaryFunc::IsNull(func::IsNull))
238    }
239
240    /// Match AND or OR on self and get the args. If no match, then interpret self as if it were
241    /// wrapped in a 1-arg AND/OR.
242    pub fn and_or_args(&self, func_to_match: VariadicFunc) -> Vec<MirScalarExpr> {
243        assert!(func_to_match == Or.into() || func_to_match == And.into());
244        match self {
245            MirScalarExpr::CallVariadic { func, exprs } if *func == func_to_match => exprs.clone(),
246            _ => vec![self.clone()],
247        }
248    }
249
250    /// Try to match a literal equality involving the given expression on one side.
251    /// Return the (non-null) literal and a bool that indicates whether an inversion was needed.
252    ///
253    /// More specifically:
254    /// If `self` is an equality with a `null` literal on any side, then the match fails!
255    /// Otherwise: for a given `expr`, if `self` is `<expr> = <literal>` or `<literal> = <expr>`
256    /// then return `Some((<literal>, false))`. In addition to just trying to match `<expr>` as it
257    /// is, we also try to remove an invertible function call (such as a cast). If the match
258    /// succeeds with the inversion, then return `Some((<inverted-literal>, true))`. For more
259    /// details on the inversion, see `invert_casts_on_expr_eq_literal_inner`.
260    pub fn expr_eq_literal(&self, expr: &MirScalarExpr) -> Option<(Row, bool)> {
261        if let MirScalarExpr::CallBinary {
262            func: BinaryFunc::Eq(_),
263            expr1,
264            expr2,
265        } = self
266        {
267            if expr1.is_literal_null() || expr2.is_literal_null() {
268                return None;
269            }
270            if let Some(Ok(lit)) = expr1.as_literal_owned() {
271                return Self::expr_eq_literal_inner(expr, lit, expr1, expr2);
272            }
273            if let Some(Ok(lit)) = expr2.as_literal_owned() {
274                return Self::expr_eq_literal_inner(expr, lit, expr2, expr1);
275            }
276        }
277        None
278    }
279
280    fn expr_eq_literal_inner(
281        expr_to_match: &MirScalarExpr,
282        literal: Row,
283        literal_expr: &MirScalarExpr,
284        other_side: &MirScalarExpr,
285    ) -> Option<(Row, bool)> {
286        if other_side == expr_to_match {
287            return Some((literal, false));
288        } else {
289            // expr didn't exactly match. See if we can match it by inverse-casting.
290            let (cast_removed, inv_cast_lit) =
291                Self::invert_casts_on_expr_eq_literal_inner(other_side, literal_expr);
292            if &cast_removed == expr_to_match {
293                if let Some(Ok(inv_cast_lit_row)) = inv_cast_lit.as_literal_owned() {
294                    return Some((inv_cast_lit_row, true));
295                }
296            }
297        }
298        None
299    }
300
301    /// If `self` is `<expr> = <literal>` or `<literal> = <expr>` then
302    /// return `<expr>`. It also tries to remove a cast (or other invertible function call) from
303    /// `<expr>` before returning it, see `invert_casts_on_expr_eq_literal_inner`.
304    pub fn any_expr_eq_literal(&self) -> Option<MirScalarExpr> {
305        if let MirScalarExpr::CallBinary {
306            func: BinaryFunc::Eq(_),
307            expr1,
308            expr2,
309        } = self
310        {
311            if expr1.is_literal() {
312                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
313                return Some(expr);
314            }
315            if expr2.is_literal() {
316                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
317                return Some(expr);
318            }
319        }
320        None
321    }
322
323    /// If the given `MirScalarExpr` is a literal equality where one side is an invertible function
324    /// call, then calls the inverse function on both sides of the equality and returns the modified
325    /// version of the given `MirScalarExpr`. Otherwise, it returns the original expression.
326    /// For more details, see `invert_casts_on_expr_eq_literal_inner`.
327    pub fn invert_casts_on_expr_eq_literal(&self) -> MirScalarExpr {
328        if let MirScalarExpr::CallBinary {
329            func: BinaryFunc::Eq(_),
330            expr1,
331            expr2,
332        } = self
333        {
334            if expr1.is_literal() {
335                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
336                return literal.call_binary(expr, func::Eq);
337            }
338            if expr2.is_literal() {
339                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
340                return literal.call_binary(expr, func::Eq);
341            }
342            // Note: The above return statements should be consistent in whether they put the
343            // literal in expr1 or expr2, for the deduplication in CanonicalizeMfp to work.
344        }
345        self.clone()
346    }
347
348    /// Given an `<expr>` and a `<literal>` that were taken out from `<expr> = <literal>` or
349    /// `<literal> = <expr>`, it tries to simplify the equality by applying the inverse function of
350    /// the outermost function call of `<expr>` (if exists):
351    ///
352    /// `<literal> = func(<inner_expr>)`, where `func` is invertible
353    ///  -->
354    /// `<func^-1(literal)> = <inner_expr>`
355    /// if `func^-1(literal)` doesn't error out, and both `func` and `func^-1` preserve uniqueness.
356    ///
357    /// The return value is the `<inner_expr>` and the literal value that we get by applying the
358    /// inverse function.
359    fn invert_casts_on_expr_eq_literal_inner(
360        expr: &MirScalarExpr,
361        literal: &MirScalarExpr,
362    ) -> (MirScalarExpr, MirScalarExpr) {
363        assert!(matches!(literal, MirScalarExpr::Literal(..)));
364
365        let temp_storage = &RowArena::new();
366        let eval = |e: &MirScalarExpr| {
367            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
368        };
369
370        if let MirScalarExpr::CallUnary {
371            func,
372            expr: inner_expr,
373        } = expr
374        {
375            if let Some(inverse_func) = func.inverse() {
376                // We don't want to remove a function call that doesn't preserve uniqueness, e.g.,
377                // if `f` is a float, we don't want to inverse-cast `f::INT = 0`, because the
378                // inserted int-to-float cast wouldn't be able to invert the rounding.
379                // Also, we don't want to insert a function call that doesn't preserve
380                // uniqueness. E.g., if `a` has an integer type, we don't want to do
381                // a surprise rounding for `WHERE a = 3.14`.
382                if func.preserves_uniqueness() && inverse_func.preserves_uniqueness() {
383                    let lit_inv = eval(&MirScalarExpr::CallUnary {
384                        func: inverse_func,
385                        expr: Box::new(literal.clone()),
386                    });
387                    // The evaluation can error out, e.g., when casting a too large int32 to int16.
388                    // This case is handled by `impossible_literal_equality_because_types`.
389                    if !lit_inv.is_literal_err() {
390                        return (*inner_expr.clone(), lit_inv);
391                    }
392                }
393            }
394        }
395        (expr.clone(), literal.clone())
396    }
397
398    /// Tries to remove a cast (or other invertible function) in the same way as
399    /// `invert_casts_on_expr_eq_literal`, but if calling the inverse function fails on the literal,
400    /// then it deems the equality to be impossible. For example if `a` is a smallint column, then
401    /// it catches `a::integer = 1000000` to be an always false predicate (where the `::integer`
402    /// could have been inserted implicitly).
403    pub fn impossible_literal_equality_because_types(&self) -> bool {
404        if let MirScalarExpr::CallBinary {
405            func: BinaryFunc::Eq(_),
406            expr1,
407            expr2,
408        } = self
409        {
410            if expr1.is_literal() {
411                return Self::impossible_literal_equality_because_types_inner(expr1, expr2);
412            }
413            if expr2.is_literal() {
414                return Self::impossible_literal_equality_because_types_inner(expr2, expr1);
415            }
416        }
417        false
418    }
419
420    fn impossible_literal_equality_because_types_inner(
421        literal: &MirScalarExpr,
422        other_side: &MirScalarExpr,
423    ) -> bool {
424        assert!(matches!(literal, MirScalarExpr::Literal(..)));
425
426        let temp_storage = &RowArena::new();
427        let eval = |e: &MirScalarExpr| {
428            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
429        };
430
431        if let MirScalarExpr::CallUnary { func, .. } = other_side {
432            if let Some(inverse_func) = func.inverse() {
433                if inverse_func.preserves_uniqueness()
434                    && eval(&MirScalarExpr::CallUnary {
435                        func: inverse_func,
436                        expr: Box::new(literal.clone()),
437                    })
438                    .is_literal_err()
439                {
440                    return true;
441                }
442            }
443        }
444
445        false
446    }
447
448    /// Determines if `self` is
449    /// `<expr> < <literal>` or
450    /// `<expr> > <literal>` or
451    /// `<literal> < <expr>` or
452    /// `<literal> > <expr>` or
453    /// `<expr> <= <literal>` or
454    /// `<expr> >= <literal>` or
455    /// `<literal> <= <expr>` or
456    /// `<literal> >= <expr>`.
457    pub fn any_expr_ineq_literal(&self) -> bool {
458        match self {
459            MirScalarExpr::CallBinary {
460                func:
461                    BinaryFunc::Lt(_) | BinaryFunc::Lte(_) | BinaryFunc::Gt(_) | BinaryFunc::Gte(_),
462                expr1,
463                expr2,
464            } => expr1.is_literal() || expr2.is_literal(),
465            _ => false,
466        }
467    }
468
469    /// Rewrites column indices with their value in `permutation`.
470    ///
471    /// This method is applicable even when `permutation` is not a
472    /// strict permutation, and it only needs to have entries for
473    /// each column referenced in `self`.
474    pub fn permute(&mut self, permutation: &[usize]) {
475        self.visit_columns(|c| *c = permutation[*c]);
476    }
477
478    /// Rewrites column indices with their value in `permutation`.
479    ///
480    /// This method is applicable even when `permutation` is not a
481    /// strict permutation, and it only needs to have entries for
482    /// each column referenced in `self`.
483    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
484        self.visit_columns(|c| *c = permutation[c]);
485    }
486
487    /// Visits each column reference and applies `action` to the column.
488    ///
489    /// Useful for remapping columns, or for collecting expression support.
490    pub fn visit_columns<F>(&mut self, mut action: F)
491    where
492        F: FnMut(&mut usize),
493    {
494        self.visit_pre_mut(|e| {
495            if let MirScalarExpr::Column(col, _) = e {
496                action(col);
497            }
498        });
499    }
500
501    pub fn support(&self) -> BTreeSet<usize> {
502        let mut support = BTreeSet::new();
503        self.support_into(&mut support);
504        support
505    }
506
507    pub fn support_into(&self, support: &mut BTreeSet<usize>) {
508        self.visit_pre(|e| {
509            if let MirScalarExpr::Column(i, _) = e {
510                support.insert(*i);
511            }
512        });
513    }
514
515    pub fn take(&mut self) -> Self {
516        mem::replace(self, MirScalarExpr::literal_null(ReprScalarType::String))
517    }
518
519    /// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
520    /// Otherwise, it returns None.
521    pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
522        if let MirScalarExpr::Literal(lit, _column_type) = self {
523            Some(lit.as_ref().map(|row| row.unpack_first()))
524        } else {
525            None
526        }
527    }
528
529    /// Flattens the two failure modes of `as_literal` into one layer of Option: returns the
530    /// literal's Datum only if the expression is a literal, and it's not a literal error.
531    pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
532        self.as_literal().map(|eval_err| eval_err.ok()).flatten()
533    }
534
535    pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
536        if let MirScalarExpr::Literal(lit, _column_type) = self {
537            Some(lit.clone())
538        } else {
539            None
540        }
541    }
542
543    /// Returns a reference to the `Row` if the expression is a non-NULL `Ok` literal.
544    pub fn as_literal_non_null_row(&self) -> Option<&Row> {
545        if let MirScalarExpr::Literal(Ok(row), _) = self {
546            if !row.unpack_first().is_null() {
547                return Some(row);
548            }
549        }
550        None
551    }
552
553    pub fn as_literal_str(&self) -> Option<&str> {
554        match self.as_literal() {
555            Some(Ok(Datum::String(s))) => Some(s),
556            _ => None,
557        }
558    }
559
560    pub fn as_literal_int64(&self) -> Option<i64> {
561        match self.as_literal() {
562            Some(Ok(Datum::Int64(i))) => Some(i),
563            _ => None,
564        }
565    }
566
567    pub fn as_literal_err(&self) -> Option<&EvalError> {
568        self.as_literal().and_then(|lit| lit.err())
569    }
570
571    pub fn is_literal(&self) -> bool {
572        matches!(self, MirScalarExpr::Literal(_, _))
573    }
574
575    pub fn is_literal_true(&self) -> bool {
576        Some(Ok(Datum::True)) == self.as_literal()
577    }
578
579    pub fn is_literal_false(&self) -> bool {
580        Some(Ok(Datum::False)) == self.as_literal()
581    }
582
583    pub fn is_literal_null(&self) -> bool {
584        Some(Ok(Datum::Null)) == self.as_literal()
585    }
586
587    pub fn is_literal_ok(&self) -> bool {
588        matches!(self, MirScalarExpr::Literal(Ok(_), _typ))
589    }
590
591    pub fn is_literal_err(&self) -> bool {
592        matches!(self, MirScalarExpr::Literal(Err(_), _typ))
593    }
594
595    pub fn is_column(&self) -> bool {
596        matches!(self, MirScalarExpr::Column(_col, _name))
597    }
598
599    pub fn is_error_if_null(&self) -> bool {
600        matches!(
601            self,
602            Self::CallVariadic {
603                func: VariadicFunc::ErrorIfNull(_),
604                ..
605            }
606        )
607    }
608
    /// If `self` expresses a temporal filter, normalize it to start with `mz_now()` and return
    /// references.
    ///
    /// A temporal filter is an expression of the form `mz_now() <BINOP> <EXPR>`,
    /// for a restricted set of `BINOP` and `EXPR` that do not themselves contain `mz_now()`.
    /// Expressions may conform to this once their expressions are swapped.
    ///
    /// On success, returns the (possibly flipped) binary function together with a mutable
    /// reference to the non-`mz_now()` operand.
    ///
    /// # Errors
    ///
    /// If the expression is not a temporal filter, it will be unchanged, and the reason for why
    /// it's not a temporal filter is returned as a string.
    pub fn as_mut_temporal_filter(&mut self) -> Result<(&BinaryFunc, &mut MirScalarExpr), String> {
        if !self.contains_temporal() {
            return Err("Does not involve mz_now()".to_string());
        }
        // Supported temporal predicates are exclusively binary operators.
        if let MirScalarExpr::CallBinary { func, expr1, expr2 } = self {
            // Attempt to put `mz_now()` in the first argument position: when it is exactly the
            // second operand and the first operand is non-temporal, swap the operands and
            // mirror the comparison so that the predicate keeps its meaning.
            if !expr1.contains_temporal()
                && **expr2 == MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
            {
                let new_func = match func {
                    BinaryFunc::Eq(_) => func::Eq.into(),
                    BinaryFunc::Lt(_) => func::Gt.into(),
                    BinaryFunc::Lte(_) => func::Gte.into(),
                    BinaryFunc::Gt(_) => func::Lt.into(),
                    BinaryFunc::Gte(_) => func::Lte.into(),
                    x => {
                        // Bail out before anything has been mutated.
                        return Err(format!("Unsupported binary temporal operation: {:?}", x));
                    }
                };
                std::mem::swap(expr1, expr2);
                *func = new_func;
            }

            // Error if `mz_now()` is referenced in an unsupported position: the first operand
            // must be exactly `mz_now()`, and the second must not contain it at all.
            if expr2.contains_temporal()
                || **expr1 != MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
            {
                let mode = HumanizedExplain::new(false); // no redaction
                // Rebuild an owned copy of the predicate purely for the error message.
                let bad_expr = MirScalarExpr::CallBinary {
                    func: func.clone(),
                    expr1: expr1.clone(),
                    expr2: expr2.clone(),
                };
                return Err(format!(
                    "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a mz_timestamp-castable expression. Expression found: {}",
                    mode.expr(&bad_expr, None),
                ));
            }

            // Downgrade `func` to a shared reference; `expr2` stays mutable for the caller.
            Ok((&*func, expr2))
        } else {
            let mode = HumanizedExplain::new(false); // no redaction
            Err(format!(
                "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a non-temporal expression of mz_timestamp-castable type. Expression found: {}",
                mode.expr(self, None),
            ))
        }
    }
667
668    #[deprecated = "Use `might_error` instead"]
669    pub fn contains_error_if_null(&self) -> bool {
670        let mut worklist = vec![self];
671        while let Some(expr) = worklist.pop() {
672            if expr.is_error_if_null() {
673                return true;
674            }
675            worklist.extend(expr.children());
676        }
677        false
678    }
679
680    pub fn contains_err(&self) -> bool {
681        let mut worklist = vec![self];
682        while let Some(expr) = worklist.pop() {
683            if expr.is_literal_err() {
684                return true;
685            }
686            worklist.extend(expr.children());
687        }
688        false
689    }
690
691    /// A very crude approximation for scalar expressions that might produce an
692    /// error.
693    ///
694    /// Currently, this is restricted only to expressions that either contain a
695    /// literal error or a [`VariadicFunc::ErrorIfNull`] call.
696    pub fn might_error(&self) -> bool {
697        let mut worklist = vec![self];
698        while let Some(expr) = worklist.pop() {
699            if expr.is_literal_err() || expr.is_error_if_null() {
700                return true;
701            }
702            worklist.extend(expr.children());
703        }
704        false
705    }
706
707    /// If self is a column, return the column index, otherwise `None`.
708    pub fn as_column(&self) -> Option<usize> {
709        if let MirScalarExpr::Column(c, _) = self {
710            Some(*c)
711        } else {
712            None
713        }
714    }
715
716    /// Reduces a complex expression where possible.
717    ///
718    /// This function uses nullability information present in `column_types`,
719    /// and the result may only continue to be a correct transformation as
720    /// long as this information continues to hold (nullability may not hold
721    /// as expressions migrate around).
722    ///
723    /// (If you'd like to not use nullability information here, then you can
724    /// tweak the nullabilities in `column_types` before passing it to this
725    /// function, see e.g. in `EquivalenceClasses::minimize`.)
726    ///
727    /// Also performs partial canonicalization on the expression.
728    ///
729    /// ```rust
730    /// use mz_expr::MirScalarExpr;
731    /// use mz_repr::{ReprColumnType, Datum, SqlScalarType};
732    ///
733    /// let expr_0 = MirScalarExpr::column(0);
734    /// let expr_t = MirScalarExpr::literal_true();
735    /// let expr_f = MirScalarExpr::literal_false();
736    ///
737    /// let mut test =
738    /// expr_t
739    ///     .clone()
740    ///     .and(expr_f.clone())
741    ///     .if_then_else(expr_0, expr_t.clone());
742    ///
743    /// let input_type = vec![ReprColumnType::from(&SqlScalarType::Int32.nullable(false))];
744    /// test.reduce(&input_type);
745    /// assert_eq!(test, expr_t);
746    /// ```
747    /// Reduce the expression to a simpler form.
748    pub fn reduce(&mut self, column_types: &[ReprColumnType]) {
749        let temp_storage = &RowArena::new();
750        let eval = |e: &MirScalarExpr| {
751            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type)
752        };
753
754        // Simplifications run in a loop until `self` no longer changes.
755        let mut old_self = MirScalarExpr::column(0);
756        while old_self != *self {
757            old_self = self.clone();
758            #[allow(deprecated)]
759            self.visit_mut_pre_post_nolimit(
760                &mut |e| {
761                    match e {
762                        MirScalarExpr::CallUnary { func, expr } => {
763                            if *func == UnaryFunc::IsNull(func::IsNull) {
764                                if !expr.typ(column_types).nullable {
765                                    *e = MirScalarExpr::literal_false();
766                                } else {
767                                    // Try to at least decompose IsNull into a disjunction
768                                    // of simpler IsNull subexpressions.
769                                    if let Some(expr) = expr.decompose_is_null() {
770                                        *e = expr
771                                    }
772                                }
773                            } else if *func == UnaryFunc::Not(func::Not) {
774                                // Push down not expressions
775                                match &mut **expr {
776                                    // Two negates cancel each other out.
777                                    MirScalarExpr::CallUnary {
778                                        expr: inner_expr,
779                                        func: UnaryFunc::Not(func::Not),
780                                    } => *e = inner_expr.take(),
781                                    // Transforms `NOT(a <op> b)` to `a negate(<op>) b`
782                                    // if a negation exists.
783                                    MirScalarExpr::CallBinary { expr1, expr2, func } => {
784                                        if let Some(negated_func) = func.negate() {
785                                            *e = MirScalarExpr::CallBinary {
786                                                expr1: Box::new(expr1.take()),
787                                                expr2: Box::new(expr2.take()),
788                                                func: negated_func,
789                                            }
790                                        }
791                                    }
792                                    MirScalarExpr::CallVariadic { .. } => {
793                                        e.demorgans();
794                                    }
795                                    _ => {}
796                                }
797                            }
798                        }
799                        _ => {}
800                    };
801                    None
802                },
803                &mut |e| match e {
804                    // Evaluate and pull up constants
805                    MirScalarExpr::Column(_, _)
806                    | MirScalarExpr::Literal(_, _)
807                    | MirScalarExpr::CallUnmaterializable(_) => (),
808                    MirScalarExpr::CallUnary { func, expr } => {
809                        if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
810                            *e = eval(e);
811                        } else if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
812                            if let MirScalarExpr::CallVariadic {
813                                func: VariadicFunc::RecordCreate(..),
814                                exprs,
815                            } = &mut **expr
816                            {
817                                *e = exprs.swap_remove(i);
818                            }
819                        }
820                    }
821                    MirScalarExpr::CallBinary { func, expr1, expr2 } => {
822                        if expr1.is_literal() && expr2.is_literal() {
823                            *e = eval(e);
824                        } else if (expr1.is_literal_null() || expr2.is_literal_null())
825                            && func.propagates_nulls()
826                        {
827                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
828                        } else if let Some(err) = expr1.as_literal_err() {
829                            *e = MirScalarExpr::literal(
830                                Err(err.clone()),
831                                e.typ(column_types).scalar_type,
832                            );
833                        } else if let Some(err) = expr2.as_literal_err() {
834                            *e = MirScalarExpr::literal(
835                                Err(err.clone()),
836                                e.typ(column_types).scalar_type,
837                            );
838                        } else if let BinaryFunc::IsLikeMatchCaseInsensitive(_) = func {
839                            if expr2.is_literal() {
840                                // We can at least precompile the regex.
841                                let pattern = expr2.as_literal_str().unwrap();
842                                *e = match like_pattern::compile(pattern, true) {
843                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
844                                        func::IsLikeMatch(matcher),
845                                    )),
846                                    Err(err) => MirScalarExpr::literal(
847                                        Err(err),
848                                        e.typ(column_types).scalar_type,
849                                    ),
850                                };
851                            }
852                        } else if let BinaryFunc::IsLikeMatchCaseSensitive(_) = func {
853                            if expr2.is_literal() {
854                                // We can at least precompile the regex.
855                                let pattern = expr2.as_literal_str().unwrap();
856                                *e = match like_pattern::compile(pattern, false) {
857                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
858                                        func::IsLikeMatch(matcher),
859                                    )),
860                                    Err(err) => MirScalarExpr::literal(
861                                        Err(err),
862                                        e.typ(column_types).scalar_type,
863                                    ),
864                                };
865                            }
866                        } else if matches!(
867                            func,
868                            BinaryFunc::IsRegexpMatchCaseSensitive(_)
869                                | BinaryFunc::IsRegexpMatchCaseInsensitive(_)
870                        ) {
871                            let case_insensitive =
872                                matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
873                            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
874                                *e = match Regex::new(
875                                    row.unpack_first().unwrap_str(),
876                                    case_insensitive,
877                                ) {
878                                    Ok(regex) => expr1.take().call_unary(UnaryFunc::IsRegexpMatch(
879                                        func::IsRegexpMatch(regex),
880                                    )),
881                                    Err(err) => MirScalarExpr::literal(
882                                        Err(err.into()),
883                                        e.typ(column_types).scalar_type,
884                                    ),
885                                };
886                            }
887                        } else if let BinaryFunc::ExtractInterval(_) = *func
888                            && expr1.is_literal()
889                        {
890                            let units = expr1.as_literal_str().unwrap();
891                            *e = match units.parse::<DateTimeUnits>() {
892                                Ok(units) => MirScalarExpr::CallUnary {
893                                    func: UnaryFunc::ExtractInterval(func::ExtractInterval(units)),
894                                    expr: Box::new(expr2.take()),
895                                },
896                                Err(_) => MirScalarExpr::literal(
897                                    Err(EvalError::UnknownUnits(units.into())),
898                                    e.typ(column_types).scalar_type,
899                                ),
900                            }
901                        } else if let BinaryFunc::ExtractTime(_) = *func
902                            && expr1.is_literal()
903                        {
904                            let units = expr1.as_literal_str().unwrap();
905                            *e = match units.parse::<DateTimeUnits>() {
906                                Ok(units) => MirScalarExpr::CallUnary {
907                                    func: UnaryFunc::ExtractTime(func::ExtractTime(units)),
908                                    expr: Box::new(expr2.take()),
909                                },
910                                Err(_) => MirScalarExpr::literal(
911                                    Err(EvalError::UnknownUnits(units.into())),
912                                    e.typ(column_types).scalar_type,
913                                ),
914                            }
915                        } else if let BinaryFunc::ExtractTimestamp(_) = *func
916                            && expr1.is_literal()
917                        {
918                            let units = expr1.as_literal_str().unwrap();
919                            *e = match units.parse::<DateTimeUnits>() {
920                                Ok(units) => MirScalarExpr::CallUnary {
921                                    func: UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(
922                                        units,
923                                    )),
924                                    expr: Box::new(expr2.take()),
925                                },
926                                Err(_) => MirScalarExpr::literal(
927                                    Err(EvalError::UnknownUnits(units.into())),
928                                    e.typ(column_types).scalar_type,
929                                ),
930                            }
931                        } else if let BinaryFunc::ExtractTimestampTz(_) = *func
932                            && expr1.is_literal()
933                        {
934                            let units = expr1.as_literal_str().unwrap();
935                            *e = match units.parse::<DateTimeUnits>() {
936                                Ok(units) => MirScalarExpr::CallUnary {
937                                    func: UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(
938                                        units,
939                                    )),
940                                    expr: Box::new(expr2.take()),
941                                },
942                                Err(_) => MirScalarExpr::literal(
943                                    Err(EvalError::UnknownUnits(units.into())),
944                                    e.typ(column_types).scalar_type,
945                                ),
946                            }
947                        } else if let BinaryFunc::ExtractDate(_) = *func
948                            && expr1.is_literal()
949                        {
950                            let units = expr1.as_literal_str().unwrap();
951                            *e = match units.parse::<DateTimeUnits>() {
952                                Ok(units) => MirScalarExpr::CallUnary {
953                                    func: UnaryFunc::ExtractDate(func::ExtractDate(units)),
954                                    expr: Box::new(expr2.take()),
955                                },
956                                Err(_) => MirScalarExpr::literal(
957                                    Err(EvalError::UnknownUnits(units.into())),
958                                    e.typ(column_types).scalar_type,
959                                ),
960                            }
961                        } else if let BinaryFunc::DatePartInterval(_) = *func
962                            && expr1.is_literal()
963                        {
964                            let units = expr1.as_literal_str().unwrap();
965                            *e = match units.parse::<DateTimeUnits>() {
966                                Ok(units) => MirScalarExpr::CallUnary {
967                                    func: UnaryFunc::DatePartInterval(func::DatePartInterval(
968                                        units,
969                                    )),
970                                    expr: Box::new(expr2.take()),
971                                },
972                                Err(_) => MirScalarExpr::literal(
973                                    Err(EvalError::UnknownUnits(units.into())),
974                                    e.typ(column_types).scalar_type,
975                                ),
976                            }
977                        } else if let BinaryFunc::DatePartTime(_) = *func
978                            && expr1.is_literal()
979                        {
980                            let units = expr1.as_literal_str().unwrap();
981                            *e = match units.parse::<DateTimeUnits>() {
982                                Ok(units) => MirScalarExpr::CallUnary {
983                                    func: UnaryFunc::DatePartTime(func::DatePartTime(units)),
984                                    expr: Box::new(expr2.take()),
985                                },
986                                Err(_) => MirScalarExpr::literal(
987                                    Err(EvalError::UnknownUnits(units.into())),
988                                    e.typ(column_types).scalar_type,
989                                ),
990                            }
991                        } else if let BinaryFunc::DatePartTimestamp(_) = *func
992                            && expr1.is_literal()
993                        {
994                            let units = expr1.as_literal_str().unwrap();
995                            *e = match units.parse::<DateTimeUnits>() {
996                                Ok(units) => MirScalarExpr::CallUnary {
997                                    func: UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(
998                                        units,
999                                    )),
1000                                    expr: Box::new(expr2.take()),
1001                                },
1002                                Err(_) => MirScalarExpr::literal(
1003                                    Err(EvalError::UnknownUnits(units.into())),
1004                                    e.typ(column_types).scalar_type,
1005                                ),
1006                            }
1007                        } else if let BinaryFunc::DatePartTimestampTz(_) = *func
1008                            && expr1.is_literal()
1009                        {
1010                            let units = expr1.as_literal_str().unwrap();
1011                            *e = match units.parse::<DateTimeUnits>() {
1012                                Ok(units) => MirScalarExpr::CallUnary {
1013                                    func: UnaryFunc::DatePartTimestampTz(
1014                                        func::DatePartTimestampTz(units),
1015                                    ),
1016                                    expr: Box::new(expr2.take()),
1017                                },
1018                                Err(_) => MirScalarExpr::literal(
1019                                    Err(EvalError::UnknownUnits(units.into())),
1020                                    e.typ(column_types).scalar_type,
1021                                ),
1022                            }
1023                        } else if let BinaryFunc::DateTruncTimestamp(_) = *func
1024                            && expr1.is_literal()
1025                        {
1026                            let units = expr1.as_literal_str().unwrap();
1027                            *e = match units.parse::<DateTimeUnits>() {
1028                                Ok(units) => MirScalarExpr::CallUnary {
1029                                    func: UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(
1030                                        units,
1031                                    )),
1032                                    expr: Box::new(expr2.take()),
1033                                },
1034                                Err(_) => MirScalarExpr::literal(
1035                                    Err(EvalError::UnknownUnits(units.into())),
1036                                    e.typ(column_types).scalar_type,
1037                                ),
1038                            }
1039                        } else if let BinaryFunc::DateTruncTimestampTz(_) = *func
1040                            && expr1.is_literal()
1041                        {
1042                            let units = expr1.as_literal_str().unwrap();
1043                            *e = match units.parse::<DateTimeUnits>() {
1044                                Ok(units) => MirScalarExpr::CallUnary {
1045                                    func: UnaryFunc::DateTruncTimestampTz(
1046                                        func::DateTruncTimestampTz(units),
1047                                    ),
1048                                    expr: Box::new(expr2.take()),
1049                                },
1050                                Err(_) => MirScalarExpr::literal(
1051                                    Err(EvalError::UnknownUnits(units.into())),
1052                                    e.typ(column_types).scalar_type,
1053                                ),
1054                            }
1055                        } else if matches!(func, BinaryFunc::TimezoneTimestampBinary(_))
1056                            && expr1.is_literal()
1057                        {
1058                            // If the timezone argument is a literal, and we're applying the function on many rows at the same
1059                            // time we really don't want to parse it again and again, so we parse it once and embed it into the
1060                            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
1061                            let tz = expr1.as_literal_str().unwrap();
1062                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1063                                Ok(tz) => MirScalarExpr::CallUnary {
1064                                    func: UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz)),
1065                                    expr: Box::new(expr2.take()),
1066                                },
1067                                Err(err) => MirScalarExpr::literal(
1068                                    Err(err),
1069                                    e.typ(column_types).scalar_type,
1070                                ),
1071                            }
1072                        } else if matches!(func, BinaryFunc::TimezoneTimestampTzBinary(_))
1073                            && expr1.is_literal()
1074                        {
1075                            let tz = expr1.as_literal_str().unwrap();
1076                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1077                                Ok(tz) => MirScalarExpr::CallUnary {
1078                                    func: UnaryFunc::TimezoneTimestampTz(
1079                                        func::TimezoneTimestampTz(tz),
1080                                    ),
1081                                    expr: Box::new(expr2.take()),
1082                                },
1083                                Err(err) => MirScalarExpr::literal(
1084                                    Err(err),
1085                                    e.typ(column_types).scalar_type,
1086                                ),
1087                            }
1088                        } else if let BinaryFunc::ToCharTimestamp(_) = *func
1089                            && expr2.is_literal()
1090                        {
1091                            let format_str = expr2.as_literal_str().unwrap();
1092                            *e = MirScalarExpr::CallUnary {
1093                                func: UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
1094                                    format_string: format_str.to_string(),
1095                                    format: DateTimeFormat::compile(format_str),
1096                                }),
1097                                expr: Box::new(expr1.take()),
1098                            };
1099                        } else if let BinaryFunc::ToCharTimestampTz(_) = *func
1100                            && expr2.is_literal()
1101                        {
1102                            let format_str = expr2.as_literal_str().unwrap();
1103                            *e = MirScalarExpr::CallUnary {
1104                                func: UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
1105                                    format_string: format_str.to_string(),
1106                                    format: DateTimeFormat::compile(format_str),
1107                                }),
1108                                expr: Box::new(expr1.take()),
1109                            };
1110                        } else if matches!(*func, BinaryFunc::Eq(_) | BinaryFunc::NotEq(_))
1111                            && expr2 < expr1
1112                        {
1113                            // Canonically order elements so that deduplication works better.
1114                            // Also, the below `Literal([c1, c2]) = record_create(e1, e2)` matching
1115                            // relies on this canonical ordering.
1116                            mem::swap(expr1, expr2);
1117                        } else if let (
1118                            BinaryFunc::Eq(_),
1119                            MirScalarExpr::Literal(
1120                                Ok(lit_row),
1121                                ReprColumnType {
1122                                    scalar_type:
1123                                        ReprScalarType::Record {
1124                                            fields: field_types,
1125                                            ..
1126                                        },
1127                                    ..
1128                                },
1129                            ),
1130                            MirScalarExpr::CallVariadic {
1131                                func: VariadicFunc::RecordCreate(..),
1132                                exprs: rec_create_args,
1133                            },
1134                        ) = (&*func, &**expr1, &**expr2)
1135                        {
1136                            // Literal([c1, c2]) = record_create(e1, e2)
1137                            //  -->
1138                            // c1 = e1 AND c2 = e2
1139                            //
1140                            // (Records are represented as lists.)
1141                            //
1142                            // `MapFilterProject::literal_constraints` relies on this transform,
1143                            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
1144                            match lit_row.unpack_first() {
1145                                Datum::List(datum_list) => {
1146                                    *e = MirScalarExpr::call_variadic(
1147                                        And,
1148                                        datum_list
1149                                            .iter()
1150                                            .zip_eq(field_types)
1151                                            .zip_eq(rec_create_args)
1152                                            .map(|((d, typ), a)| {
1153                                                MirScalarExpr::literal_ok(
1154                                                    d,
1155                                                    typ.scalar_type.clone(),
1156                                                )
1157                                                .call_binary(a.clone(), func::Eq)
1158                                            })
1159                                            .collect(),
1160                                    );
1161                                }
1162                                _ => {}
1163                            }
1164                        } else if let (
1165                            BinaryFunc::Eq(_),
1166                            MirScalarExpr::CallVariadic {
1167                                func: VariadicFunc::RecordCreate(..),
1168                                exprs: rec_create_args1,
1169                            },
1170                            MirScalarExpr::CallVariadic {
1171                                func: VariadicFunc::RecordCreate(..),
1172                                exprs: rec_create_args2,
1173                            },
1174                        ) = (&*func, &**expr1, &**expr2)
1175                        {
1176                            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
1177                            //  -->
1178                            // a1 = b1 AND a2 = b2 AND ...
1179                            //
1180                            // This is similar to the previous reduction, but this one kicks in also
1181                            // when only some (or none) of the record fields are literals. This
1182                            // enables the discovery of literal constraints for those fields.
1183                            //
1184                            // Note that there is a similar decomposition in
1185                            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
1186                            // pipeline than the compilation of IN lists to `record_create`.
1187                            *e = MirScalarExpr::call_variadic(
1188                                And,
1189                                rec_create_args1
1190                                    .into_iter()
1191                                    .zip_eq(rec_create_args2)
1192                                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
1193                                    .collect(),
1194                            )
1195                        }
1196                    }
1197                    MirScalarExpr::CallVariadic { .. } => {
1198                        e.flatten_associative();
1199                        let (func, exprs) = match e {
1200                            MirScalarExpr::CallVariadic { func, exprs } => (func, exprs),
1201                            _ => unreachable!("`flatten_associative` shouldn't change node type"),
1202                        };
1203                        if *func == Coalesce.into() {
1204                            // If all inputs are null, output is null. This check must
1205                            // be done before `exprs.retain...` because `e.typ` requires
1206                            // > 0 `exprs` remain.
1207                            if exprs.iter().all(|expr| expr.is_literal_null()) {
1208                                *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1209                                return;
1210                            }
1211
1212                            // Remove any null values if not all values are null.
1213                            exprs.retain(|e| !e.is_literal_null());
1214
1215                            // Find the first argument that is a literal or non-nullable
1216                            // column. All arguments after it get ignored, so throw them
1217                            // away. This intentionally throws away errors that can
1218                            // never happen.
1219                            if let Some(i) = exprs
1220                                .iter()
1221                                .position(|e| e.is_literal() || !e.typ(column_types).nullable)
1222                            {
1223                                exprs.truncate(i + 1);
1224                            }
1225
1226                            // Deduplicate arguments in cases like `coalesce(#0, #0)`.
1227                            let mut prior_exprs = BTreeSet::new();
1228                            exprs.retain(|e| prior_exprs.insert(e.clone()));
1229
1230                            if exprs.len() == 1 {
1231                                // Only one argument, so the coalesce is a no-op.
1232                                *e = exprs[0].take();
1233                            }
1234                        } else if exprs.iter().all(|e| e.is_literal()) {
1235                            *e = eval(e);
1236                        } else if func.propagates_nulls()
1237                            && exprs.iter().any(|e| e.is_literal_null())
1238                        {
1239                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1240                        } else if let Some(err) = exprs.iter().find_map(|e| e.as_literal_err()) {
1241                            *e = MirScalarExpr::literal(
1242                                Err(err.clone()),
1243                                e.typ(column_types).scalar_type,
1244                            );
1245                        } else if *func == RegexpMatch.into()
1246                            && exprs[1].is_literal()
1247                            && exprs.get(2).map_or(true, |e| e.is_literal())
1248                        {
1249                            let needle = exprs[1].as_literal_str().unwrap();
1250                            let flags = match exprs.len() {
1251                                3 => exprs[2].as_literal_str().unwrap(),
1252                                _ => "",
1253                            };
1254                            *e = match func::build_regex(needle, flags) {
1255                                Ok(regex) => mem::take(exprs)
1256                                    .into_first()
1257                                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
1258                                Err(err) => MirScalarExpr::literal(
1259                                    Err(err),
1260                                    e.typ(column_types).scalar_type,
1261                                ),
1262                            };
1263                        } else if *func == RegexpReplace.into()
1264                            && exprs[1].is_literal()
1265                            && exprs.get(3).map_or(true, |e| e.is_literal())
1266                        {
1267                            let pattern = exprs[1].as_literal_str().unwrap();
1268                            let flags = exprs
1269                                .get(3)
1270                                .map_or("", |expr| expr.as_literal_str().unwrap());
1271                            let (limit, flags) = regexp_replace_parse_flags(flags);
1272
1273                            // The behavior of `regexp_replace` is that if the data is `NULL`, the
1274                            // function returns `NULL`, independently of whether the pattern or
1275                            // flags are correct. We need to check for this case and introduce an
1276                            // if-then-else on the error path to only surface the error if both
1277                            // the source and replacement inputs are non-NULL.
1278                            *e = match func::build_regex(pattern, &flags) {
1279                                Ok(regex) => {
1280                                    let mut exprs = mem::take(exprs);
1281                                    let replacement = exprs.swap_remove(2);
1282                                    let source = exprs.swap_remove(0);
1283                                    source.call_binary(
1284                                        replacement,
1285                                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
1286                                    )
1287                                }
1288                                Err(err) => {
1289                                    let mut exprs = mem::take(exprs);
1290                                    let replacement = exprs.swap_remove(2);
1291                                    let source = exprs.swap_remove(0);
1292                                    let scalar_type = e.typ(column_types).scalar_type;
1293                                    // We need to return `NULL` on `NULL` input, and error otherwise.
1294                                    source
1295                                        .call_is_null()
1296                                        .or(replacement.call_is_null())
1297                                        .if_then_else(
1298                                            MirScalarExpr::literal_null(scalar_type.clone()),
1299                                            MirScalarExpr::literal(Err(err), scalar_type),
1300                                        )
1301                                }
1302                            };
1303                        } else if *func == RegexpSplitToArray.into()
1304                            && exprs[1].is_literal()
1305                            && exprs.get(2).map_or(true, |e| e.is_literal())
1306                        {
1307                            let needle = exprs[1].as_literal_str().unwrap();
1308                            let flags = match exprs.len() {
1309                                3 => exprs[2].as_literal_str().unwrap(),
1310                                _ => "",
1311                            };
1312                            *e = match func::build_regex(needle, flags) {
1313                                Ok(regex) => mem::take(exprs).into_first().call_unary(
1314                                    UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(regex)),
1315                                ),
1316                                Err(err) => MirScalarExpr::literal(
1317                                    Err(err),
1318                                    e.typ(column_types).scalar_type,
1319                                ),
1320                            };
1321                        } else if *func == ListIndex.into() && is_list_create_call(&exprs[0]) {
1322                            // We are looking for ListIndex(ListCreate, literal), and eliminate
1323                            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
1324                            let ind_exprs = exprs.split_off(1);
1325                            let top_list_create = exprs.swap_remove(0);
1326                            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
1327                        } else if *func == Or.into() || *func == And.into() {
1328                            // Note: It's important that we have called `flatten_associative` above.
1329                            e.undistribute_and_or();
1330                            e.reduce_and_canonicalize_and_or();
1331                        } else if let VariadicFunc::TimezoneTimeVariadic(_) = func {
1332                            if exprs[0].is_literal() && exprs[2].is_literal_ok() {
1333                                let tz = exprs[0].as_literal_str().unwrap();
1334                                *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1335                                    Ok(tz) => MirScalarExpr::CallUnary {
1336                                        func: UnaryFunc::TimezoneTime(func::TimezoneTime {
1337                                            tz,
1338                                            wall_time: exprs[2]
1339                                                .as_literal()
1340                                                .unwrap()
1341                                                .unwrap()
1342                                                .unwrap_timestamptz()
1343                                                .naive_utc(),
1344                                        }),
1345                                        expr: Box::new(exprs[1].take()),
1346                                    },
1347                                    Err(err) => MirScalarExpr::literal(
1348                                        Err(err),
1349                                        e.typ(column_types).scalar_type,
1350                                    ),
1351                                }
1352                            }
1353                        }
1354                    }
1355                    MirScalarExpr::If { cond, then, els } => {
1356                        if let Some(literal) = cond.as_literal() {
1357                            match literal {
1358                                Ok(Datum::True) => *e = then.take(),
1359                                Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
1360                                Err(err) => {
1361                                    *e = MirScalarExpr::Literal(
1362                                        Err(err.clone()),
1363                                        then.typ(column_types)
1364                                            .union(&els.typ(column_types))
1365                                            .unwrap(),
1366                                    )
1367                                }
1368                                _ => unreachable!(),
1369                            }
1370                        } else if then == els {
1371                            *e = then.take();
1372                        } else if then.is_literal_ok()
1373                            && els.is_literal_ok()
1374                            && then.typ(column_types).scalar_type == ReprScalarType::Bool
1375                            && els.typ(column_types).scalar_type == ReprScalarType::Bool
1376                        {
1377                            match (then.as_literal(), els.as_literal()) {
1378                                // Note: NULLs from the condition should not be propagated to the result
1379                                // of the expression.
1380                                (Some(Ok(Datum::True)), _) => {
1381                                    // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
1382                                    // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
1383                                    *e = cond
1384                                        .clone()
1385                                        .call_is_null()
1386                                        .not()
1387                                        .and(cond.take())
1388                                        .or(els.take());
1389                                }
1390                                (Some(Ok(Datum::False)), _) => {
1391                                    // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
1392                                    // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
1393                                    *e = cond
1394                                        .clone()
1395                                        .not()
1396                                        .or(cond.take().call_is_null())
1397                                        .and(els.take());
1398                                }
1399                                (_, Some(Ok(Datum::True))) => {
1400                                    // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
1401                                    // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
1402                                    *e = cond
1403                                        .clone()
1404                                        .not()
1405                                        .or(cond.take().call_is_null())
1406                                        .or(then.take());
1407                                }
1408                                (_, Some(Ok(Datum::False))) => {
1409                                    // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
1410                                    // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
1411                                    *e = cond
1412                                        .clone()
1413                                        .call_is_null()
1414                                        .not()
1415                                        .and(cond.take())
1416                                        .and(then.take());
1417                                }
1418                                _ => {}
1419                            }
1420                        } else {
1421                            // Equivalent expression structure would allow us to push the `If` into the expression.
1422                            // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
1423                            //
1424                            // We have to also make sure that the expressions that will end up in
1425                            // the two `If` branches have unionable types. Otherwise, the `If` could
1426                            // not be typed by `typ`. An example where this could cause an issue is
1427                            // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
1428                            // range of input types. (In theory, we could still do the optimization
1429                            // in this case by inserting appropriate casts, but this corner case is
1430                            // not worth the complication for now.)
1431                            // See https://github.com/MaterializeInc/database-issues/issues/9182
1432                            match (&mut **then, &mut **els) {
1433                                (
1434                                    MirScalarExpr::CallUnary { func: f1, expr: e1 },
1435                                    MirScalarExpr::CallUnary { func: f2, expr: e2 },
1436                                ) if f1 == f2 && e1.typ(column_types) == e2.typ(column_types) => {
1437                                    *e = cond
1438                                        .take()
1439                                        .if_then_else(e1.take(), e2.take())
1440                                        .call_unary(f1.clone());
1441                                }
1442                                (
1443                                    MirScalarExpr::CallBinary {
1444                                        func: f1,
1445                                        expr1: e1a,
1446                                        expr2: e2a,
1447                                    },
1448                                    MirScalarExpr::CallBinary {
1449                                        func: f2,
1450                                        expr1: e1b,
1451                                        expr2: e2b,
1452                                    },
1453                                ) if f1 == f2
1454                                    && e1a == e1b
1455                                    && e2a.typ(column_types) == e2b.typ(column_types) =>
1456                                {
1457                                    *e = e1a.take().call_binary(
1458                                        cond.take().if_then_else(e2a.take(), e2b.take()),
1459                                        f1.clone(),
1460                                    );
1461                                }
1462                                (
1463                                    MirScalarExpr::CallBinary {
1464                                        func: f1,
1465                                        expr1: e1a,
1466                                        expr2: e2a,
1467                                    },
1468                                    MirScalarExpr::CallBinary {
1469                                        func: f2,
1470                                        expr1: e1b,
1471                                        expr2: e2b,
1472                                    },
1473                                ) if f1 == f2
1474                                    && e2a == e2b
1475                                    && e1a.typ(column_types) == e1b.typ(column_types) =>
1476                                {
1477                                    *e = cond
1478                                        .take()
1479                                        .if_then_else(e1a.take(), e1b.take())
1480                                        .call_binary(e2a.take(), f1.clone());
1481                                }
1482                                _ => {}
1483                            }
1484                        }
1485                    }
1486                },
1487            );
1488        }
1489
1490        /* #region `reduce_list_create_list_index_literal` and helper functions */
1491
1492        fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
1493            if let MirScalarExpr::CallVariadic {
1494                func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
1495                ..
1496            } = list_create
1497            {
1498                ReprScalarType::from(typ)
1499            } else {
1500                unreachable!()
1501            }
1502        }
1503
1504        fn is_list_create_call(expr: &MirScalarExpr) -> bool {
1505            matches!(
1506                expr,
1507                MirScalarExpr::CallVariadic {
1508                    func: VariadicFunc::ListCreate(..),
1509                    ..
1510                }
1511            )
1512        }
1513
1514        /// Partial-evaluates a list indexing with a literal directly after a list creation.
1515        ///
1516        /// Multi-dimensional lists are handled by a single call to this function, with multiple
1517        /// elements in index_exprs (of which not all need to be literals), and nested ListCreates
1518        /// in list_create_to_reduce.
1519        ///
1520        /// # Examples
1521        ///
1522        /// `LIST[f1,f2][2]` --> `f2`.
1523        ///
1524        /// A multi-dimensional list, with only some of the indexes being literals:
1525        /// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
1526        ///
1527        /// See more examples in list.slt.
1528        fn reduce_list_create_list_index_literal(
1529            mut list_create_to_reduce: MirScalarExpr,
1530            mut index_exprs: Vec<MirScalarExpr>,
1531        ) -> MirScalarExpr {
1532            // We iterate over the index_exprs and remove literals, but keep non-literals.
1533            // When we encounter a non-literal, we need to dig into the nested ListCreates:
1534            // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
1535            // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
1536            // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
1537            // that are at the current level (except those that disappeared due to
1538            // literals at earlier levels), index into them with the literal, and change each
1539            // element in `list_create_mut_refs` to the result.
1540            // We also record mut refs to all the earlier `element_type` references that we have
1541            // seen in ListCreate calls, because when we process a literal index, we need to remove
1542            // one layer of list type from all these earlier ListCreate `element_type`s.
1543            let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
1544            let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
1545            let mut i = 0;
1546            while i < index_exprs.len()
1547                && list_create_mut_refs
1548                    .iter()
1549                    .all(|lc| is_list_create_call(lc))
1550            {
1551                if index_exprs[i].is_literal_ok() {
1552                    // We can remove this index.
1553                    let removed_index = index_exprs.remove(i);
1554                    let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
1555                        Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
1556                        _ => unreachable!(), // always an Int64, see plan_index_list
1557                    };
1558                    // For each list_create referenced by list_create_mut_refs, substitute it by its
1559                    // `index`th argument (or null).
1560                    for list_create in &mut list_create_mut_refs {
1561                        let list_create_args = match list_create {
1562                            MirScalarExpr::CallVariadic {
1563                                func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
1564                                exprs,
1565                            } => exprs,
1566                            _ => unreachable!(), // func cannot be anything else than a ListCreate
1567                        };
1568                        // ListIndex gives null on an out-of-bounds index
1569                        if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap()
1570                        {
1571                            let index: usize = index_i64.try_into().unwrap();
1572                            **list_create = list_create_args.swap_remove(index);
1573                        } else {
1574                            let typ = list_create_type(list_create);
1575                            **list_create = MirScalarExpr::literal_null(typ);
1576                        }
1577                    }
1578                    // Peel one layer off of each of the earlier element types.
1579                    for t in earlier_list_create_types.iter_mut() {
1580                        if let SqlScalarType::List {
1581                            element_type,
1582                            custom_id: _,
1583                        } = t
1584                        {
1585                            **t = *element_type.clone();
1586                            // These are not the same types anymore, so remove custom_ids all the
1587                            // way down.
1588                            let mut u = &mut **t;
1589                            while let SqlScalarType::List {
1590                                element_type,
1591                                custom_id,
1592                            } = u
1593                            {
1594                                *custom_id = None;
1595                                u = &mut **element_type;
1596                            }
1597                        } else {
1598                            unreachable!("already matched below");
1599                        }
1600                    }
1601                } else {
1602                    // We can't remove this index, so we can't reduce any of the ListCreates at this
1603                    // level. So we change list_create_mut_refs to refer to all the arguments of all
1604                    // the ListCreates currently referenced by list_create_mut_refs.
1605                    list_create_mut_refs = list_create_mut_refs
1606                        .into_iter()
1607                        .flat_map(|list_create| match list_create {
1608                            MirScalarExpr::CallVariadic {
1609                                func: VariadicFunc::ListCreate(ListCreate { elem_type }),
1610                                exprs: list_create_args,
1611                            } => {
1612                                earlier_list_create_types.push(elem_type);
1613                                list_create_args
1614                            }
1615                            // func cannot be anything else than a ListCreate
1616                            _ => unreachable!(),
1617                        })
1618                        .collect();
1619                    i += 1; // next index_expr
1620                }
1621            }
1622            // If all list indexes have been evaluated, return the reduced expression.
1623            // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
1624            if index_exprs.is_empty() {
1625                assert_eq!(list_create_mut_refs.len(), 1);
1626                list_create_to_reduce
1627            } else {
1628                MirScalarExpr::call_variadic(
1629                    ListIndex,
1630                    std::iter::once(list_create_to_reduce)
1631                        .chain(index_exprs)
1632                        .collect(),
1633                )
1634            }
1635        }
1636
1637        /* #endregion */
1638    }
1639
1640    /// Decompose an IsNull expression into a disjunction of
1641    /// simpler expressions.
1642    ///
1643    /// Assumes that `self` is the expression inside of an IsNull.
1644    /// Returns `Some(expressions)` if the outer IsNull is to be
1645    /// replaced by some other expression. Note: if it returns
1646    /// None, it might still have mutated *self.
1647    fn decompose_is_null(&mut self) -> Option<MirScalarExpr> {
1648        // TODO: allow simplification of unmaterializable functions
1649
1650        match self {
1651            MirScalarExpr::CallUnary {
1652                func,
1653                expr: inner_expr,
1654            } => {
1655                if !func.introduces_nulls() {
1656                    if func.propagates_nulls() {
1657                        *self = inner_expr.take();
1658                        return self.decompose_is_null();
1659                    } else {
1660                        // We can simplify to `false`, because the function simply can't produce
1661                        // nulls at all. This is because
1662                        // - !propagates_nulls means that the input type of the Rust function is not
1663                        //   nullable, so the automatic null propagation won't kick in;
1664                        // - !introduces_nulls means that the output type of the Rust function is
1665                        //   not nullable, so the Rust function can't produce a null manually either.
1666                        //
1667                        // Note that we can't do this same optimization for binary and variadic
1668                        // functions. This is because for binary and variadic functions the value of
1669                        // propagates_nulls and introduces_nulls is not derived solely from the
1670                        // input/output type nullabilities, but instead depends on what the Rust
1671                        // function does. For example, list concatenation neither introduces nor
1672                        // propagates nulls, but it can produce a null:
1673                        // - It does not introduce nulls, because if both input lists are not null,
1674                        //   then it will produce a list.
1675                        // - It does not propagate nulls, because giving a null as just one of the
1676                        //   arguments returns the other argument instead of null.
1677                        // - It does produce a null if both arguments are null.
1678                        return Some(MirScalarExpr::literal_false());
1679                    }
1680                }
1681            }
1682            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1683                // (<expr1> <op> <expr2>) IS NULL can often be simplified to
1684                // (<expr1> IS NULL) OR (<expr2> IS NULL).
1685                if func.propagates_nulls() && !func.introduces_nulls() {
1686                    let expr1 = expr1.take().call_is_null();
1687                    let expr2 = expr2.take().call_is_null();
1688                    return Some(expr1.or(expr2));
1689                }
1690            }
1691            MirScalarExpr::CallVariadic { func, exprs } => {
1692                if func.propagates_nulls() && !func.introduces_nulls() {
1693                    let exprs = exprs.into_iter().map(|e| e.take().call_is_null()).collect();
1694                    return Some(MirScalarExpr::call_variadic(Or, exprs));
1695                }
1696            }
1697            _ => {}
1698        }
1699
1700        None
1701    }
1702
1703    /// Flattens a chain of calls to associative variadic functions
1704    /// (For example: ORs or ANDs)
1705    pub fn flatten_associative(&mut self) {
1706        match self {
1707            MirScalarExpr::CallVariadic {
1708                exprs: outer_operands,
1709                func: outer_func,
1710            } if outer_func.is_associative() => {
1711                *outer_operands = outer_operands
1712                    .into_iter()
1713                    .flat_map(|o| {
1714                        if let MirScalarExpr::CallVariadic {
1715                            exprs: inner_operands,
1716                            func: inner_func,
1717                        } = o
1718                        {
1719                            if *inner_func == *outer_func {
1720                                mem::take(inner_operands)
1721                            } else {
1722                                vec![o.take()]
1723                            }
1724                        } else {
1725                            vec![o.take()]
1726                        }
1727                    })
1728                    .collect();
1729            }
1730            _ => {}
1731        }
1732    }
1733
1734    /* #region AND/OR canonicalization and transformations  */
1735
1736    /// Canonicalizes AND/OR, and does some straightforward simplifications
1737    fn reduce_and_canonicalize_and_or(&mut self) {
1738        // We do this until fixed point, because after undistribute_and_or calls us, it relies on
1739        // the property that self is not an 1-arg AND/OR. Just one application of our loop body
1740        // can't ensure this, because the application itself might create a 1-arg AND/OR.
1741        let mut old_self = MirScalarExpr::column(0);
1742        while old_self != *self {
1743            old_self = self.clone();
1744            match self {
1745                MirScalarExpr::CallVariadic {
1746                    func: func @ (VariadicFunc::And(_) | VariadicFunc::Or(_)),
1747                    exprs,
1748                } => {
1749                    // Canonically order elements so that various deduplications work better,
1750                    // e.g., in undistribute_and_or.
1751                    // Also, extract_equal_or_both_null_inner depends on the args being sorted.
1752                    exprs.sort();
1753
1754                    // x AND/OR x --> x
1755                    exprs.dedup(); // this also needs the above sorting
1756
1757                    if exprs.len() == 1 {
1758                        // AND/OR of 1 argument evaluates to that argument
1759                        *self = exprs.swap_remove(0);
1760                    } else if exprs.len() == 0 {
1761                        // AND/OR of 0 arguments evaluates to true/false
1762                        *self = func.unit_of_and_or();
1763                    } else if exprs.iter().any(|e| *e == func.zero_of_and_or()) {
1764                        // short-circuiting
1765                        *self = func.zero_of_and_or();
1766                    } else {
1767                        // a AND true --> a
1768                        // a OR false --> a
1769                        exprs.retain(|e| *e != func.unit_of_and_or());
1770                    }
1771                }
1772                _ => {}
1773            }
1774        }
1775    }
1776
1777    /// Transforms !(a && b) into !a || !b, and !(a || b) into !a && !b
1778    fn demorgans(&mut self) {
1779        if let MirScalarExpr::CallUnary {
1780            expr: inner,
1781            func: UnaryFunc::Not(func::Not),
1782        } = self
1783        {
1784            inner.flatten_associative();
1785            match &mut **inner {
1786                MirScalarExpr::CallVariadic {
1787                    func: inner_func @ (VariadicFunc::And(_) | VariadicFunc::Or(_)),
1788                    exprs,
1789                } => {
1790                    *inner_func = inner_func.switch_and_or();
1791                    *exprs = exprs.into_iter().map(|e| e.take().not()).collect();
1792                    *self = (*inner).take(); // Removes the outer not
1793                }
1794                _ => {}
1795            }
1796        }
1797    }
1798
1799    /// AND/OR undistribution (factoring out) to apply at each `MirScalarExpr`.
1800    ///
1801    /// This method attempts to apply one of the [distribution laws][distributivity]
1802    /// (in a direction opposite to their name):
1803    /// ```text
1804    /// (a && b) || (a && c) --> a && (b || c)  // Undistribute-OR
1805    /// (a || b) && (a || c) --> a || (b && c)  // Undistribute-AND
1806    /// ```
1807    /// or one of their corresponding two [absorption law][absorption] special
1808    /// cases:
1809    /// ```text
1810    /// a || (a && c)  -->  a  // Absorb-OR
1811    /// a && (a || c)  -->  a  // Absorb-AND
1812    /// ```
1813    ///
1814    /// The method also works with more than 2 arguments at the top, e.g.
1815    /// ```text
1816    /// (a && b) || (a && c) || (a && d)  -->  a && (b || c || d)
1817    /// ```
1818    /// It can also factor out only a subset of the top arguments, e.g.
1819    /// ```text
1820    /// (a && b) || (a && c) || (d && e)  -->  (a && (b || c)) || (d && e)
1821    /// ```
1822    ///
1823    /// Note that sometimes there are two overlapping possibilities to factor
1824    /// out from, e.g.
1825    /// ```text
1826    /// (a && b) || (a && c) || (d && c)
1827    /// ```
1828    /// Here we can factor out `a` from the 1. and 2. terms, or we can
1829    /// factor out `c` from the 2. and 3. terms. One of these might lead to
1830    /// more/better undistribution opportunities later, but we just pick one
1831    /// locally, because recursively trying out all of them would lead to
1832    /// exponential run time.
1833    ///
1834    /// The local heuristic is that we prefer a candidate that leads to an
1835    /// absorption, or if there is no such one then we simply pick the first. In
1836    /// case of multiple absorption candidates, it doesn't matter which one we
1837    /// pick, because applying an absorption cannot adversely affect the
1838    /// possibility of applying other absorptions.
1839    ///
1840    /// # Assumption
1841    ///
1842    /// Assumes that nested chains of AND/OR applications are flattened (this
1843    /// can be enforced with [`Self::flatten_associative`]).
1844    ///
1845    /// # Examples
1846    ///
1847    /// Absorb-OR:
1848    /// ```text
1849    /// a || (a && c) || (a && d)
1850    /// -->
1851    /// a && (true || c || d)
1852    /// -->
1853    /// a && true
1854    /// -->
1855    /// a
1856    /// ```
1857    /// Here only the first step is performed by this method. The rest is done
1858    /// by [`Self::reduce_and_canonicalize_and_or`] called after us in
1859    /// `reduce()`.
1860    ///
1861    /// [distributivity]: https://en.wikipedia.org/wiki/Distributive_property
1862    /// [absorption]: https://en.wikipedia.org/wiki/Absorption_law
1863    fn undistribute_and_or(&mut self) {
1864        // It wouldn't be strictly necessary to wrap this fn in this loop, because `reduce()` calls
1865        // us in a loop anyway. However, `reduce()` tries to do many other things, so the loop here
1866        // improves performance when there are several undistributions to apply in sequence, which
1867        // can occur in `CanonicalizeMfp` when undoing the DNF.
1868        let mut old_self = MirScalarExpr::column(0);
1869        while old_self != *self {
1870            old_self = self.clone();
1871            self.reduce_and_canonicalize_and_or(); // We don't want to deal with 1-arg AND/OR at the top
1872            if let MirScalarExpr::CallVariadic {
1873                exprs: outer_operands,
1874                func: outer_func @ (VariadicFunc::Or(_) | VariadicFunc::And(_)),
1875            } = self
1876            {
1877                let inner_func = outer_func.switch_and_or();
1878
1879                // Make sure that each outer operand is a call to inner_func, by wrapping in a 1-arg
1880                // call if necessary.
1881                outer_operands.iter_mut().for_each(|o| {
1882                    if !matches!(o, MirScalarExpr::CallVariadic {func: f, ..} if *f == inner_func) {
1883                        *o = MirScalarExpr::CallVariadic {
1884                            func: inner_func.clone(),
1885                            exprs: vec![o.take()],
1886                        };
1887                    }
1888                });
1889
1890                let mut inner_operands_refs: Vec<&mut Vec<MirScalarExpr>> = outer_operands
1891                    .iter_mut()
1892                    .map(|o| match o {
1893                        MirScalarExpr::CallVariadic { func: f, exprs } if *f == inner_func => exprs,
1894                        _ => unreachable!(), // the wrapping made sure that we'll get a match
1895                    })
1896                    .collect();
1897
1898                // Find inner operands to undistribute, i.e., which are in _all_ of the outer operands.
1899                let mut intersection = inner_operands_refs
1900                    .iter()
1901                    .map(|v| (*v).clone())
1902                    .reduce(|ops1, ops2| ops1.into_iter().filter(|e| ops2.contains(e)).collect())
1903                    .unwrap();
1904                intersection.sort();
1905                intersection.dedup();
1906
1907                if !intersection.is_empty() {
1908                    // Factor out the intersection from all the top-level args.
1909
1910                    // Remove the intersection from each inner operand vector.
1911                    inner_operands_refs
1912                        .iter_mut()
1913                        .for_each(|ops| (**ops).retain(|o| !intersection.contains(o)));
1914
1915                    // Simplify terms that now have only 0 or 1 args due to removing the intersection.
1916                    outer_operands
1917                        .iter_mut()
1918                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1919
1920                    // Add the intersection at the beginning
1921                    *self = MirScalarExpr::CallVariadic {
1922                        func: inner_func,
1923                        exprs: intersection.into_iter().chain_one(self.clone()).collect(),
1924                    };
1925                } else {
1926                    // If the intersection was empty, that means that there is nothing we can factor out
1927                    // from _all_ the top-level args. However, we might still find something to factor
1928                    // out from a subset of the top-level args. To find such an opportunity, we look for
1929                    // duplicates across all inner args, e.g. if we have
1930                    // `(...) OR (... AND `a` AND ...) OR (...) OR (... AND `a` AND ...)`
1931                    // then we'll find that `a` occurs in more than one top-level arg, so
1932                    // `indexes_to_undistribute` will point us to the 2. and 4. top-level args.
1933
1934                    // Create (inner_operand, index) pairs, where the index is the position in
1935                    // outer_operands
1936                    let all_inner_operands = inner_operands_refs
1937                        .iter()
1938                        .enumerate()
1939                        .flat_map(|(i, inner_vec)| inner_vec.iter().map(move |a| ((*a).clone(), i)))
1940                        .sorted()
1941                        .collect_vec();
1942
1943                    // Find inner operand expressions that occur in more than one top-level arg.
1944                    // Each inner vector in `undistribution_opportunities` will belong to one such inner
1945                    // operand expression, and it is a set of indexes pointing to top-level args where
1946                    // that inner operand occurs.
1947                    let undistribution_opportunities = all_inner_operands
1948                        .iter()
1949                        .chunk_by(|(a, _i)| a)
1950                        .into_iter()
1951                        .map(|(_a, g)| g.map(|(_a, i)| *i).sorted().dedup().collect_vec())
1952                        .filter(|g| g.len() > 1)
1953                        .collect_vec();
1954
1955                    // Choose one of the inner vectors from `undistribution_opportunities`.
1956                    let indexes_to_undistribute = undistribution_opportunities
1957                        .iter()
1958                        // Let's prefer index sets that directly lead to an absorption.
1959                        .find(|index_set| {
1960                            index_set
1961                                .iter()
1962                                .any(|i| inner_operands_refs.get(*i).unwrap().len() == 1)
1963                        })
1964                        // If we didn't find any absorption, then any index set will do.
1965                        .or_else(|| undistribution_opportunities.first())
1966                        .cloned();
1967
1968                    // In any case, undo the 1-arg wrapping that we did at the beginning.
1969                    outer_operands
1970                        .iter_mut()
1971                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1972
1973                    if let Some(indexes_to_undistribute) = indexes_to_undistribute {
1974                        // Found something to undistribute from a subset of the outer operands.
1975                        // We temporarily remove these from outer_operands, call ourselves on it, and
1976                        // then push back the result.
1977                        let mut undistribute_from = MirScalarExpr::CallVariadic {
1978                            func: outer_func.clone(),
1979                            exprs: swap_remove_multiple(outer_operands, indexes_to_undistribute),
1980                        };
1981                        // By construction, the recursive call is guaranteed to hit
1982                        // the `!intersection.is_empty()` branch.
1983                        undistribute_from.undistribute_and_or();
1984                        // Append the undistributed result to outer operands that were not included in
1985                        // indexes_to_undistribute.
1986                        outer_operands.push(undistribute_from);
1987                    }
1988                }
1989            }
1990        }
1991    }
1992
1993    /* #endregion */
1994
1995    /// Adds any columns that *must* be non-Null for `self` to be non-Null.
1996    pub fn non_null_requirements(&self, columns: &mut BTreeSet<usize>) {
1997        match self {
1998            MirScalarExpr::Column(col, _name) => {
1999                columns.insert(*col);
2000            }
2001            MirScalarExpr::Literal(..) => {}
2002            MirScalarExpr::CallUnmaterializable(_) => (),
2003            MirScalarExpr::CallUnary { func, expr } => {
2004                if func.propagates_nulls() {
2005                    expr.non_null_requirements(columns);
2006                }
2007            }
2008            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2009                if func.propagates_nulls() {
2010                    expr1.non_null_requirements(columns);
2011                    expr2.non_null_requirements(columns);
2012                }
2013            }
2014            MirScalarExpr::CallVariadic { func, exprs } => {
2015                if func.propagates_nulls() {
2016                    for expr in exprs {
2017                        expr.non_null_requirements(columns);
2018                    }
2019                }
2020            }
2021            MirScalarExpr::If { .. } => (),
2022        }
2023    }
2024
2025    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2026        let repr_column_types = column_types.iter().map(ReprColumnType::from).collect_vec();
2027        SqlColumnType::from_repr(&self.typ(&repr_column_types))
2028    }
2029
2030    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2031        match self {
2032            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
2033            MirScalarExpr::Literal(_, typ) => typ.clone(),
2034            MirScalarExpr::CallUnmaterializable(func) => func.output_type(),
2035            MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)),
2036            MirScalarExpr::CallBinary { expr1, expr2, func } => {
2037                func.output_type(&[expr1.typ(column_types), expr2.typ(column_types)])
2038            }
2039            MirScalarExpr::CallVariadic { exprs, func } => {
2040                func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect())
2041            }
2042            MirScalarExpr::If { cond: _, then, els } => {
2043                let then_type = then.typ(column_types);
2044                let else_type = els.typ(column_types);
2045                then_type.union(&else_type).unwrap()
2046            }
2047        }
2048    }
2049
2050    pub fn eval<'a>(
2051        &'a self,
2052        datums: &[Datum<'a>],
2053        temp_storage: &'a RowArena,
2054    ) -> Result<Datum<'a>, EvalError> {
2055        match self {
2056            MirScalarExpr::Column(index, _name) => Ok(datums[*index]),
2057            MirScalarExpr::Literal(res, _column_type) => match res {
2058                Ok(row) => Ok(row.unpack_first()),
2059                Err(e) => Err(e.clone()),
2060            },
2061            // Unmaterializable functions must be transformed away before
2062            // evaluation. Their purpose is as a placeholder for data that is
2063            // not known at plan time but can be inlined before runtime.
2064            MirScalarExpr::CallUnmaterializable(x) => Err(EvalError::Internal(
2065                format!("cannot evaluate unmaterializable function: {:?}", x).into(),
2066            )),
2067            MirScalarExpr::CallUnary { func, expr } => func.eval(datums, temp_storage, expr),
2068            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2069                func.eval(datums, temp_storage, &[expr1, expr2])
2070            }
2071            MirScalarExpr::CallVariadic { func, exprs } => func.eval(datums, temp_storage, exprs),
2072            MirScalarExpr::If { cond, then, els } => match cond.eval(datums, temp_storage)? {
2073                Datum::True => then.eval(datums, temp_storage),
2074                Datum::False | Datum::Null => els.eval(datums, temp_storage),
2075                d => Err(EvalError::Internal(
2076                    format!("if condition evaluated to non-boolean datum: {:?}", d).into(),
2077                )),
2078            },
2079        }
2080    }
2081
2082    /// True iff the expression contains
2083    /// `UnmaterializableFunc::MzNow`.
2084    pub fn contains_temporal(&self) -> bool {
2085        let mut contains = false;
2086        self.visit_pre(|e| {
2087            if let MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2088                contains = true;
2089            }
2090        });
2091        contains
2092    }
2093
2094    /// True iff the expression contains an `UnmaterializableFunc`.
2095    pub fn contains_unmaterializable(&self) -> bool {
2096        let mut contains = false;
2097        self.visit_pre(|e| {
2098            if let MirScalarExpr::CallUnmaterializable(_) = e {
2099                contains = true;
2100            }
2101        });
2102        contains
2103    }
2104
2105    /// True iff the expression contains an `UnmaterializableFunc` that is not in the `exceptions`
2106    /// list.
2107    pub fn contains_unmaterializable_except(&self, exceptions: &[UnmaterializableFunc]) -> bool {
2108        let mut contains = false;
2109        self.visit_pre(|e| match e {
2110            MirScalarExpr::CallUnmaterializable(f) if !exceptions.contains(f) => contains = true,
2111            _ => (),
2112        });
2113        contains
2114    }
2115
2116    /// True iff the expression contains a `Column`.
2117    pub fn contains_column(&self) -> bool {
2118        let mut contains = false;
2119        self.visit_pre(|e| {
2120            if let MirScalarExpr::Column(_col, _name) = e {
2121                contains = true;
2122            }
2123        });
2124        contains
2125    }
2126
2127    /// True iff the expression contains a `Dummy`.
2128    pub fn contains_dummy(&self) -> bool {
2129        let mut contains = false;
2130        self.visit_pre(|e| {
2131            if let MirScalarExpr::Literal(row, _) = e {
2132                if let Ok(row) = row {
2133                    contains |= row.iter().any(|d| d.contains_dummy());
2134                }
2135            }
2136        });
2137        contains
2138    }
2139
2140    /// The size of the expression as a tree.
2141    pub fn size(&self) -> usize {
2142        let mut size = 0;
2143        self.visit_pre(&mut |_: &MirScalarExpr| {
2144            size += 1;
2145        });
2146        size
2147    }
2148}
2149
2150impl MirScalarExpr {
2151    /// True iff evaluation could possibly error on non-error input `Datum`.
2152    pub fn could_error(&self) -> bool {
2153        match self {
2154            MirScalarExpr::Column(_col, _name) => false,
2155            MirScalarExpr::Literal(row, ..) => row.is_err(),
2156            MirScalarExpr::CallUnmaterializable(_) => true,
2157            MirScalarExpr::CallUnary { func, expr } => func.could_error() || expr.could_error(),
2158            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2159                func.could_error() || expr1.could_error() || expr2.could_error()
2160            }
2161            MirScalarExpr::CallVariadic { func, exprs } => {
2162                func.could_error() || exprs.iter().any(|e| e.could_error())
2163            }
2164            MirScalarExpr::If { cond, then, els } => {
2165                cond.could_error() || then.could_error() || els.could_error()
2166            }
2167        }
2168    }
2169}
2170
2171impl VisitChildren<Self> for MirScalarExpr {
2172    fn visit_children<F>(&self, mut f: F)
2173    where
2174        F: FnMut(&Self),
2175    {
2176        use MirScalarExpr::*;
2177        match self {
2178            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2179            CallUnary { expr, .. } => {
2180                f(expr);
2181            }
2182            CallBinary { expr1, expr2, .. } => {
2183                f(expr1);
2184                f(expr2);
2185            }
2186            CallVariadic { exprs, .. } => {
2187                for expr in exprs {
2188                    f(expr);
2189                }
2190            }
2191            If { cond, then, els } => {
2192                f(cond);
2193                f(then);
2194                f(els);
2195            }
2196        }
2197    }
2198
2199    fn visit_mut_children<F>(&mut self, mut f: F)
2200    where
2201        F: FnMut(&mut Self),
2202    {
2203        use MirScalarExpr::*;
2204        match self {
2205            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2206            CallUnary { expr, .. } => {
2207                f(expr);
2208            }
2209            CallBinary { expr1, expr2, .. } => {
2210                f(expr1);
2211                f(expr2);
2212            }
2213            CallVariadic { exprs, .. } => {
2214                for expr in exprs {
2215                    f(expr);
2216                }
2217            }
2218            If { cond, then, els } => {
2219                f(cond);
2220                f(then);
2221                f(els);
2222            }
2223        }
2224    }
2225
2226    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2227    where
2228        F: FnMut(&Self) -> Result<(), E>,
2229        E: From<RecursionLimitError>,
2230    {
2231        use MirScalarExpr::*;
2232        match self {
2233            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2234            CallUnary { expr, .. } => {
2235                f(expr)?;
2236            }
2237            CallBinary { expr1, expr2, .. } => {
2238                f(expr1)?;
2239                f(expr2)?;
2240            }
2241            CallVariadic { exprs, .. } => {
2242                for expr in exprs {
2243                    f(expr)?;
2244                }
2245            }
2246            If { cond, then, els } => {
2247                f(cond)?;
2248                f(then)?;
2249                f(els)?;
2250            }
2251        }
2252        Ok(())
2253    }
2254
2255    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2256    where
2257        F: FnMut(&mut Self) -> Result<(), E>,
2258        E: From<RecursionLimitError>,
2259    {
2260        use MirScalarExpr::*;
2261        match self {
2262            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2263            CallUnary { expr, .. } => {
2264                f(expr)?;
2265            }
2266            CallBinary { expr1, expr2, .. } => {
2267                f(expr1)?;
2268                f(expr2)?;
2269            }
2270            CallVariadic { exprs, .. } => {
2271                for expr in exprs {
2272                    f(expr)?;
2273                }
2274            }
2275            If { cond, then, els } => {
2276                f(cond)?;
2277                f(then)?;
2278                f(els)?;
2279            }
2280        }
2281        Ok(())
2282    }
2283}
2284
2285impl MirScalarExpr {
2286    /// Iterates through references to child expressions.
2287    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2288        let mut first = None;
2289        let mut second = None;
2290        let mut third = None;
2291        let mut variadic = None;
2292
2293        use MirScalarExpr::*;
2294        match self {
2295            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2296            CallUnary { expr, .. } => {
2297                first = Some(&**expr);
2298            }
2299            CallBinary { expr1, expr2, .. } => {
2300                first = Some(&**expr1);
2301                second = Some(&**expr2);
2302            }
2303            CallVariadic { exprs, .. } => {
2304                variadic = Some(exprs);
2305            }
2306            If { cond, then, els } => {
2307                first = Some(&**cond);
2308                second = Some(&**then);
2309                third = Some(&**els);
2310            }
2311        }
2312
2313        first
2314            .into_iter()
2315            .chain(second)
2316            .chain(third)
2317            .chain(variadic.into_iter().flatten())
2318    }
2319
2320    /// Iterates through mutable references to child expressions.
2321    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2322        let mut first = None;
2323        let mut second = None;
2324        let mut third = None;
2325        let mut variadic = None;
2326
2327        use MirScalarExpr::*;
2328        match self {
2329            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2330            CallUnary { expr, .. } => {
2331                first = Some(&mut **expr);
2332            }
2333            CallBinary { expr1, expr2, .. } => {
2334                first = Some(&mut **expr1);
2335                second = Some(&mut **expr2);
2336            }
2337            CallVariadic { exprs, .. } => {
2338                variadic = Some(exprs);
2339            }
2340            If { cond, then, els } => {
2341                first = Some(&mut **cond);
2342                second = Some(&mut **then);
2343                third = Some(&mut **els);
2344            }
2345        }
2346
2347        first
2348            .into_iter()
2349            .chain(second)
2350            .chain(third)
2351            .chain(variadic.into_iter().flatten())
2352    }
2353
2354    /// Visits all subexpressions in DFS preorder.
2355    pub fn visit_pre<F>(&self, mut f: F)
2356    where
2357        F: FnMut(&Self),
2358    {
2359        let mut worklist = vec![self];
2360        while let Some(e) = worklist.pop() {
2361            f(e);
2362            worklist.extend(e.children().rev());
2363        }
2364    }
2365
2366    /// Iterative pre-order visitor.
2367    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2368        let mut worklist = vec![self];
2369        while let Some(expr) = worklist.pop() {
2370            f(expr);
2371            worklist.extend(expr.children_mut().rev());
2372        }
2373    }
2374}
2375
2376/// Filter characteristics that are used for ordering join inputs.
2377/// This can be created for a `Vec<MirScalarExpr>`, which represents an AND of predicates.
2378///
2379/// The fields are ordered based on heuristic assumptions about their typical selectivity, so that
2380/// Ord gives the right ordering for join inputs. Bigger is better, i.e., will tend to come earlier
2381/// than other inputs.
2382#[derive(
2383    Eq,
2384    PartialEq,
2385    Ord,
2386    PartialOrd,
2387    Debug,
2388    Clone,
2389    Serialize,
2390    Deserialize,
2391    Hash,
2392    MzReflect
2393)]
2394pub struct FilterCharacteristics {
2395    // `<expr> = <literal>` appears in the filter.
2396    // Excludes cases where NOT appears anywhere above the literal equality.
2397    literal_equality: bool,
2398    // (Assuming a random string of lower-case characters, `LIKE 'a%'` has a selectivity of 1/26.)
2399    like: bool,
2400    is_null: bool,
2401    // Number of Vec elements that involve inequality predicates. (A BETWEEN is represented as two
2402    // inequality predicates.)
2403    // Excludes cases where NOT appears around the literal inequality.
2404    // Note that for inequality predicates, some databases assume 1/3 selectivity in the absence of
2405    // concrete statistics.
2406    literal_inequality: usize,
2407    /// Any filter, except ones involving `IS NOT NULL`, because those are too common.
2408    /// Can be true by itself, or any other field being true can also make this true.
2409    /// `NOT LIKE` is only in this category.
2410    /// `!=` is only in this category.
2411    /// `NOT (a = b)` is turned into `!=` by `reduce` before us!
2412    any_filter: bool,
2413}
2414
2415impl BitOrAssign for FilterCharacteristics {
2416    fn bitor_assign(&mut self, rhs: Self) {
2417        self.literal_equality |= rhs.literal_equality;
2418        self.like |= rhs.like;
2419        self.is_null |= rhs.is_null;
2420        self.literal_inequality += rhs.literal_inequality;
2421        self.any_filter |= rhs.any_filter;
2422    }
2423}
2424
2425impl FilterCharacteristics {
2426    pub fn none() -> FilterCharacteristics {
2427        FilterCharacteristics {
2428            literal_equality: false,
2429            like: false,
2430            is_null: false,
2431            literal_inequality: 0,
2432            any_filter: false,
2433        }
2434    }
2435
2436    pub fn explain(&self) -> String {
2437        let mut e = "".to_owned();
2438        if self.literal_equality {
2439            e.push_str("e");
2440        }
2441        if self.like {
2442            e.push_str("l");
2443        }
2444        if self.is_null {
2445            e.push_str("n");
2446        }
2447        for _ in 0..self.literal_inequality {
2448            e.push_str("i");
2449        }
2450        if self.any_filter {
2451            e.push_str("f");
2452        }
2453        e
2454    }
2455
2456    pub fn filter_characteristics(
2457        filters: &Vec<MirScalarExpr>,
2458    ) -> Result<FilterCharacteristics, RecursionLimitError> {
2459        let mut literal_equality = false;
2460        let mut like = false;
2461        let mut is_null = false;
2462        let mut literal_inequality = 0;
2463        let mut any_filter = false;
2464        filters.iter().try_for_each(|f| {
2465            let mut literal_inequality_in_current_filter = false;
2466            let mut is_not_null_in_current_filter = false;
2467            f.visit_pre_with_context(
2468                false,
2469                &mut |not_in_parent_chain, expr| {
2470                    not_in_parent_chain
2471                        || matches!(
2472                            expr,
2473                            MirScalarExpr::CallUnary {
2474                                func: UnaryFunc::Not(func::Not),
2475                                ..
2476                            }
2477                        )
2478                },
2479                &mut |not_in_parent_chain, expr| {
2480                    if !not_in_parent_chain {
2481                        if expr.any_expr_eq_literal().is_some() {
2482                            literal_equality = true;
2483                        }
2484                        if expr.any_expr_ineq_literal() {
2485                            literal_inequality_in_current_filter = true;
2486                        }
2487                        if matches!(
2488                            expr,
2489                            MirScalarExpr::CallUnary {
2490                                func: UnaryFunc::IsLikeMatch(_),
2491                                ..
2492                            }
2493                        ) {
2494                            like = true;
2495                        }
2496                    };
2497                    if matches!(
2498                        expr,
2499                        MirScalarExpr::CallUnary {
2500                            func: UnaryFunc::IsNull(crate::func::IsNull),
2501                            ..
2502                        }
2503                    ) {
2504                        if *not_in_parent_chain {
2505                            is_not_null_in_current_filter = true;
2506                        } else {
2507                            is_null = true;
2508                        }
2509                    }
2510                },
2511            )?;
2512            if literal_inequality_in_current_filter {
2513                literal_inequality += 1;
2514            }
2515            if !is_not_null_in_current_filter {
2516                // We want to ignore `IS NOT NULL` for `any_filter`.
2517                any_filter = true;
2518            }
2519            Ok(())
2520        })?;
2521        Ok(FilterCharacteristics {
2522            literal_equality,
2523            like,
2524            is_null,
2525            literal_inequality,
2526            any_filter,
2527        })
2528    }
2529
2530    pub fn add_literal_equality(&mut self) {
2531        self.literal_equality = true;
2532    }
2533
2534    pub fn worst_case_scaling_factor(&self) -> f64 {
2535        let mut factor = 1.0;
2536
2537        if self.literal_equality {
2538            factor *= 0.1;
2539        }
2540
2541        if self.is_null {
2542            factor *= 0.1;
2543        }
2544
2545        if self.literal_inequality >= 2 {
2546            factor *= 0.25;
2547        } else if self.literal_inequality == 1 {
2548            factor *= 0.33;
2549        }
2550
2551        // catch various negated filters, treat them pessimistically
2552        if !(self.literal_equality || self.is_null || self.literal_inequality > 0)
2553            && self.any_filter
2554        {
2555            factor *= 0.9;
2556        }
2557
2558        factor
2559    }
2560}
2561
2562#[derive(
2563    Arbitrary,
2564    Ord,
2565    PartialOrd,
2566    Copy,
2567    Clone,
2568    Debug,
2569    Eq,
2570    PartialEq,
2571    Serialize,
2572    Deserialize,
2573    Hash,
2574    MzReflect
2575)]
2576pub enum DomainLimit {
2577    None,
2578    Inclusive(i64),
2579    Exclusive(i64),
2580}
2581
2582impl RustType<ProtoDomainLimit> for DomainLimit {
2583    fn into_proto(&self) -> ProtoDomainLimit {
2584        use proto_domain_limit::Kind::*;
2585        let kind = match self {
2586            DomainLimit::None => None(()),
2587            DomainLimit::Inclusive(v) => Inclusive(*v),
2588            DomainLimit::Exclusive(v) => Exclusive(*v),
2589        };
2590        ProtoDomainLimit { kind: Some(kind) }
2591    }
2592
2593    fn from_proto(proto: ProtoDomainLimit) -> Result<Self, TryFromProtoError> {
2594        use proto_domain_limit::Kind::*;
2595        if let Some(kind) = proto.kind {
2596            match kind {
2597                None(()) => Ok(DomainLimit::None),
2598                Inclusive(v) => Ok(DomainLimit::Inclusive(v)),
2599                Exclusive(v) => Ok(DomainLimit::Exclusive(v)),
2600            }
2601        } else {
2602            Err(TryFromProtoError::missing_field("ProtoDomainLimit::kind"))
2603        }
2604    }
2605}
2606
2607#[derive(
2608    Arbitrary,
2609    Ord,
2610    PartialOrd,
2611    Clone,
2612    Debug,
2613    Eq,
2614    PartialEq,
2615    Serialize,
2616    Deserialize,
2617    Hash,
2618    MzReflect
2619)]
2620pub enum EvalError {
2621    CharacterNotValidForEncoding(i32),
2622    CharacterTooLargeForEncoding(i32),
2623    DateBinOutOfRange(Box<str>),
2624    DivisionByZero,
2625    Unsupported {
2626        feature: Box<str>,
2627        discussion_no: Option<usize>,
2628    },
2629    FloatOverflow,
2630    FloatUnderflow,
2631    NumericFieldOverflow,
2632    Float32OutOfRange(Box<str>),
2633    Float64OutOfRange(Box<str>),
2634    Int16OutOfRange(Box<str>),
2635    Int32OutOfRange(Box<str>),
2636    Int64OutOfRange(Box<str>),
2637    UInt16OutOfRange(Box<str>),
2638    UInt32OutOfRange(Box<str>),
2639    UInt64OutOfRange(Box<str>),
2640    MzTimestampOutOfRange(Box<str>),
2641    MzTimestampStepOverflow,
2642    OidOutOfRange(Box<str>),
2643    IntervalOutOfRange(Box<str>),
2644    TimestampCannotBeNan,
2645    TimestampOutOfRange,
2646    DateOutOfRange,
2647    CharOutOfRange,
2648    IndexOutOfRange {
2649        provided: i32,
2650        // The last valid index position, i.e. `v.len() - 1`
2651        valid_end: i32,
2652    },
2653    InvalidBase64Equals,
2654    InvalidBase64Symbol(char),
2655    InvalidBase64EndSequence,
2656    InvalidTimezone(Box<str>),
2657    InvalidTimezoneInterval,
2658    InvalidTimezoneConversion,
2659    InvalidIanaTimezoneId(Box<str>),
2660    InvalidLayer {
2661        max_layer: usize,
2662        val: i64,
2663    },
2664    InvalidArray(InvalidArrayError),
2665    InvalidEncodingName(Box<str>),
2666    InvalidHashAlgorithm(Box<str>),
2667    InvalidByteSequence {
2668        byte_sequence: Box<str>,
2669        encoding_name: Box<str>,
2670    },
2671    InvalidJsonbCast {
2672        from: Box<str>,
2673        to: Box<str>,
2674    },
2675    InvalidRegex(Box<str>),
2676    InvalidRegexFlag(char),
2677    InvalidParameterValue(Box<str>),
2678    InvalidDatePart(Box<str>),
2679    KeyCannotBeNull,
2680    NegSqrt,
2681    NegLimit,
2682    NullCharacterNotPermitted,
2683    UnknownUnits(Box<str>),
2684    UnsupportedUnits(Box<str>, Box<str>),
2685    UnterminatedLikeEscapeSequence,
2686    Parse(ParseError),
2687    ParseHex(ParseHexError),
2688    Internal(Box<str>),
2689    InfinityOutOfDomain(Box<str>),
2690    NegativeOutOfDomain(Box<str>),
2691    ZeroOutOfDomain(Box<str>),
2692    OutOfDomain(DomainLimit, DomainLimit, Box<str>),
2693    ComplexOutOfRange(Box<str>),
2694    MultipleRowsFromSubquery,
2695    NegativeRowsFromSubquery,
2696    Undefined(Box<str>),
2697    LikePatternTooLong,
2698    LikeEscapeTooLong,
2699    StringValueTooLong {
2700        target_type: Box<str>,
2701        length: usize,
2702    },
2703    MultidimensionalArrayRemovalNotSupported,
2704    IncompatibleArrayDimensions {
2705        dims: Option<(usize, usize)>,
2706    },
2707    TypeFromOid(Box<str>),
2708    InvalidRange(InvalidRangeError),
2709    InvalidRoleId(Box<str>),
2710    InvalidPrivileges(Box<str>),
2711    InvalidCatalogJson(Box<str>),
2712    LetRecLimitExceeded(Box<str>),
2713    MultiDimensionalArraySearch,
2714    MustNotBeNull(Box<str>),
2715    InvalidIdentifier {
2716        ident: Box<str>,
2717        detail: Option<Box<str>>,
2718    },
2719    ArrayFillWrongArraySubscripts,
2720    // TODO: propagate this check more widely throughout the expr crate
2721    MaxArraySizeExceeded(usize),
2722    DateDiffOverflow {
2723        unit: Box<str>,
2724        a: Box<str>,
2725        b: Box<str>,
2726    },
2727    // The error for ErrorIfNull; this should not be used in other contexts as a generic error
2728    // printer.
2729    IfNullError(Box<str>),
2730    LengthTooLarge,
2731    AclArrayNullElement,
2732    MzAclArrayNullElement,
2733    PrettyError(Box<str>),
2734    RedactError(Box<str>),
2735}
2736
2737impl fmt::Display for EvalError {
2738    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2739        match self {
2740            EvalError::CharacterNotValidForEncoding(v) => {
2741                write!(f, "requested character not valid for encoding: {v}")
2742            }
2743            EvalError::CharacterTooLargeForEncoding(v) => {
2744                write!(f, "requested character too large for encoding: {v}")
2745            }
2746            EvalError::DateBinOutOfRange(message) => f.write_str(message),
2747            EvalError::DivisionByZero => f.write_str("division by zero"),
2748            EvalError::Unsupported {
2749                feature,
2750                discussion_no,
2751            } => {
2752                write!(f, "{} not yet supported", feature)?;
2753                if let Some(discussion_no) = discussion_no {
2754                    write!(
2755                        f,
2756                        ", see https://github.com/MaterializeInc/materialize/discussions/{} for more details",
2757                        discussion_no
2758                    )?;
2759                }
2760                Ok(())
2761            }
2762            EvalError::FloatOverflow => f.write_str("value out of range: overflow"),
2763            EvalError::FloatUnderflow => f.write_str("value out of range: underflow"),
2764            EvalError::NumericFieldOverflow => f.write_str("numeric field overflow"),
2765            EvalError::Float32OutOfRange(val) => write!(f, "{} real out of range", val.quoted()),
2766            EvalError::Float64OutOfRange(val) => {
2767                write!(f, "{} double precision out of range", val.quoted())
2768            }
2769            EvalError::Int16OutOfRange(val) => write!(f, "{} smallint out of range", val.quoted()),
2770            EvalError::Int32OutOfRange(val) => write!(f, "{} integer out of range", val.quoted()),
2771            EvalError::Int64OutOfRange(val) => write!(f, "{} bigint out of range", val.quoted()),
2772            EvalError::UInt16OutOfRange(val) => write!(f, "{} uint2 out of range", val.quoted()),
2773            EvalError::UInt32OutOfRange(val) => write!(f, "{} uint4 out of range", val.quoted()),
2774            EvalError::UInt64OutOfRange(val) => write!(f, "{} uint8 out of range", val.quoted()),
2775            EvalError::MzTimestampOutOfRange(val) => {
2776                write!(f, "{} mz_timestamp out of range", val.quoted())
2777            }
2778            EvalError::MzTimestampStepOverflow => f.write_str("step mz_timestamp overflow"),
2779            EvalError::OidOutOfRange(val) => write!(f, "{} OID out of range", val.quoted()),
2780            EvalError::IntervalOutOfRange(val) => {
2781                write!(f, "{} interval out of range", val.quoted())
2782            }
2783            EvalError::TimestampCannotBeNan => f.write_str("timestamp cannot be NaN"),
2784            EvalError::TimestampOutOfRange => f.write_str("timestamp out of range"),
2785            EvalError::DateOutOfRange => f.write_str("date out of range"),
2786            EvalError::CharOutOfRange => f.write_str("\"char\" out of range"),
2787            EvalError::IndexOutOfRange {
2788                provided,
2789                valid_end,
2790            } => write!(f, "index {provided} out of valid range, 0..{valid_end}",),
2791            EvalError::InvalidBase64Equals => {
2792                f.write_str("unexpected \"=\" while decoding base64 sequence")
2793            }
2794            EvalError::InvalidBase64Symbol(c) => write!(
2795                f,
2796                "invalid symbol \"{}\" found while decoding base64 sequence",
2797                c.escape_default()
2798            ),
2799            EvalError::InvalidBase64EndSequence => f.write_str("invalid base64 end sequence"),
2800            EvalError::InvalidJsonbCast { from, to } => {
2801                write!(f, "cannot cast jsonb {} to type {}", from, to)
2802            }
2803            EvalError::InvalidTimezone(tz) => write!(f, "invalid time zone '{}'", tz),
2804            EvalError::InvalidTimezoneInterval => {
2805                f.write_str("timezone interval must not contain months or years")
2806            }
2807            EvalError::InvalidTimezoneConversion => f.write_str("invalid timezone conversion"),
2808            EvalError::InvalidIanaTimezoneId(tz) => {
2809                write!(f, "invalid IANA Time Zone Database identifier: '{}'", tz)
2810            }
2811            EvalError::InvalidLayer { max_layer, val } => write!(
2812                f,
2813                "invalid layer: {}; must use value within [1, {}]",
2814                val, max_layer
2815            ),
2816            EvalError::InvalidArray(e) => e.fmt(f),
2817            EvalError::InvalidEncodingName(name) => write!(f, "invalid encoding name '{}'", name),
2818            EvalError::InvalidHashAlgorithm(alg) => write!(f, "invalid hash algorithm '{}'", alg),
2819            EvalError::InvalidByteSequence {
2820                byte_sequence,
2821                encoding_name,
2822            } => write!(
2823                f,
2824                "invalid byte sequence '{}' for encoding '{}'",
2825                byte_sequence, encoding_name
2826            ),
2827            EvalError::InvalidDatePart(part) => write!(f, "invalid datepart {}", part.quoted()),
2828            EvalError::KeyCannotBeNull => f.write_str("key cannot be null"),
2829            EvalError::NegSqrt => f.write_str("cannot take square root of a negative number"),
2830            EvalError::NegLimit => f.write_str("LIMIT must not be negative"),
2831            EvalError::NullCharacterNotPermitted => f.write_str("null character not permitted"),
2832            EvalError::InvalidRegex(e) => write!(f, "invalid regular expression: {}", e),
2833            EvalError::InvalidRegexFlag(c) => write!(f, "invalid regular expression flag: {}", c),
2834            EvalError::InvalidParameterValue(s) => f.write_str(s),
2835            EvalError::UnknownUnits(units) => write!(f, "unit '{}' not recognized", units),
2836            EvalError::UnsupportedUnits(units, typ) => {
2837                write!(f, "unit '{}' not supported for type {}", units, typ)
2838            }
2839            EvalError::UnterminatedLikeEscapeSequence => {
2840                f.write_str("unterminated escape sequence in LIKE")
2841            }
2842            EvalError::Parse(e) => e.fmt(f),
2843            EvalError::PrettyError(e) => e.fmt(f),
2844            EvalError::RedactError(e) => e.fmt(f),
2845            EvalError::ParseHex(e) => e.fmt(f),
2846            EvalError::Internal(s) => write!(f, "internal error: {}", s),
2847            EvalError::InfinityOutOfDomain(s) => {
2848                write!(f, "function {} is only defined for finite arguments", s)
2849            }
2850            EvalError::NegativeOutOfDomain(s) => {
2851                write!(f, "function {} is not defined for negative numbers", s)
2852            }
2853            EvalError::ZeroOutOfDomain(s) => {
2854                write!(f, "function {} is not defined for zero", s)
2855            }
2856            EvalError::OutOfDomain(lower, upper, s) => {
2857                use DomainLimit::*;
2858                write!(f, "function {s} is defined for numbers ")?;
2859                match (lower, upper) {
2860                    (Inclusive(n), None) => write!(f, "greater than or equal to {n}"),
2861                    (Exclusive(n), None) => write!(f, "greater than {n}"),
2862                    (None, Inclusive(n)) => write!(f, "less than or equal to {n}"),
2863                    (None, Exclusive(n)) => write!(f, "less than {n}"),
2864                    (Inclusive(lo), Inclusive(hi)) => write!(f, "between {lo} and {hi} inclusive"),
2865                    (Exclusive(lo), Exclusive(hi)) => write!(f, "between {lo} and {hi} exclusive"),
2866                    (Inclusive(lo), Exclusive(hi)) => {
2867                        write!(f, "between {lo} inclusive and {hi} exclusive")
2868                    }
2869                    (Exclusive(lo), Inclusive(hi)) => {
2870                        write!(f, "between {lo} exclusive and {hi} inclusive")
2871                    }
2872                    (None, None) => panic!("invalid domain error"),
2873                }
2874            }
2875            EvalError::ComplexOutOfRange(s) => {
2876                write!(f, "function {} cannot return complex numbers", s)
2877            }
2878            EvalError::MultipleRowsFromSubquery => {
2879                write!(f, "more than one record produced in subquery")
2880            }
2881            EvalError::NegativeRowsFromSubquery => {
2882                write!(f, "negative number of rows produced in subquery")
2883            }
2884            EvalError::Undefined(s) => {
2885                write!(f, "{} is undefined", s)
2886            }
2887            EvalError::LikePatternTooLong => {
2888                write!(f, "LIKE pattern exceeds maximum length")
2889            }
2890            EvalError::LikeEscapeTooLong => {
2891                write!(f, "invalid escape string")
2892            }
2893            EvalError::StringValueTooLong {
2894                target_type,
2895                length,
2896            } => {
2897                write!(f, "value too long for type {}({})", target_type, length)
2898            }
2899            EvalError::MultidimensionalArrayRemovalNotSupported => {
2900                write!(
2901                    f,
2902                    "removing elements from multidimensional arrays is not supported"
2903                )
2904            }
2905            EvalError::IncompatibleArrayDimensions { dims: _ } => {
2906                write!(f, "cannot concatenate incompatible arrays")
2907            }
2908            EvalError::TypeFromOid(msg) => write!(f, "{msg}"),
2909            EvalError::InvalidRange(e) => e.fmt(f),
2910            EvalError::InvalidRoleId(msg) => write!(f, "{msg}"),
2911            EvalError::InvalidPrivileges(privilege) => {
2912                write!(f, "unrecognized privilege type: {privilege}")
2913            }
2914            EvalError::InvalidCatalogJson(msg) => {
2915                write!(f, "invalid catalog JSON: {msg}")
2916            }
2917            EvalError::LetRecLimitExceeded(max_iters) => {
2918                write!(
2919                    f,
2920                    "Recursive query exceeded the recursion limit {}. (Use RETURN AT RECURSION LIMIT to not error, but return the current state as the final result when reaching the limit.)",
2921                    max_iters
2922                )
2923            }
2924            EvalError::MultiDimensionalArraySearch => write!(
2925                f,
2926                "searching for elements in multidimensional arrays is not supported"
2927            ),
2928            EvalError::MustNotBeNull(v) => write!(f, "{v} must not be null"),
2929            EvalError::InvalidIdentifier { ident, .. } => {
2930                write!(f, "string is not a valid identifier: {}", ident.quoted())
2931            }
2932            EvalError::ArrayFillWrongArraySubscripts => {
2933                f.write_str("wrong number of array subscripts")
2934            }
2935            EvalError::MaxArraySizeExceeded(max_size) => {
2936                write!(
2937                    f,
2938                    "array size exceeds the maximum allowed ({max_size} bytes)"
2939                )
2940            }
2941            EvalError::DateDiffOverflow { unit, a, b } => {
2942                write!(f, "datediff overflow, {unit} of {a}, {b}")
2943            }
2944            EvalError::IfNullError(s) => f.write_str(s),
2945            EvalError::LengthTooLarge => write!(f, "requested length too large"),
2946            EvalError::AclArrayNullElement => write!(f, "ACL arrays must not contain null values"),
2947            EvalError::MzAclArrayNullElement => {
2948                write!(f, "MZ_ACL arrays must not contain null values")
2949            }
2950        }
2951    }
2952}
2953
2954impl EvalError {
2955    pub fn detail(&self) -> Option<String> {
2956        match self {
2957            EvalError::IncompatibleArrayDimensions { dims: None } => Some(
2958                "Arrays with differing dimensions are not compatible for concatenation.".into(),
2959            ),
2960            EvalError::IncompatibleArrayDimensions {
2961                dims: Some((a_dims, b_dims)),
2962            } => Some(format!(
2963                "Arrays of {} and {} dimensions are not compatible for concatenation.",
2964                a_dims, b_dims
2965            )),
2966            EvalError::InvalidIdentifier { detail, .. } => detail.as_deref().map(Into::into),
2967            EvalError::ArrayFillWrongArraySubscripts => {
2968                Some("Low bound array has different size than dimensions array.".into())
2969            }
2970            _ => None,
2971        }
2972    }
2973
2974    pub fn hint(&self) -> Option<String> {
2975        match self {
2976            EvalError::InvalidBase64EndSequence => Some(
2977                "Input data is missing padding, is truncated, or is otherwise corrupted.".into(),
2978            ),
2979            EvalError::LikeEscapeTooLong => {
2980                Some("Escape string must be empty or one character.".into())
2981            }
2982            EvalError::MzTimestampOutOfRange(_) => Some(
2983                "Integer, numeric, and text casts to mz_timestamp must be in the form of whole \
2984                milliseconds since the Unix epoch. Values with fractional parts cannot be \
2985                converted to mz_timestamp."
2986                    .into(),
2987            ),
2988            _ => None,
2989        }
2990    }
2991}
2992
2993impl std::error::Error for EvalError {}
2994
2995impl From<ParseError> for EvalError {
2996    fn from(e: ParseError) -> EvalError {
2997        EvalError::Parse(e)
2998    }
2999}
3000
3001impl From<ParseHexError> for EvalError {
3002    fn from(e: ParseHexError) -> EvalError {
3003        EvalError::ParseHex(e)
3004    }
3005}
3006
3007impl From<InvalidArrayError> for EvalError {
3008    fn from(e: InvalidArrayError) -> EvalError {
3009        EvalError::InvalidArray(e)
3010    }
3011}
3012
3013impl From<RegexCompilationError> for EvalError {
3014    fn from(e: RegexCompilationError) -> EvalError {
3015        EvalError::InvalidRegex(e.to_string().into())
3016    }
3017}
3018
3019impl From<TypeFromOidError> for EvalError {
3020    fn from(e: TypeFromOidError) -> EvalError {
3021        EvalError::TypeFromOid(e.to_string().into())
3022    }
3023}
3024
3025impl From<DateError> for EvalError {
3026    fn from(e: DateError) -> EvalError {
3027        match e {
3028            DateError::OutOfRange => EvalError::DateOutOfRange,
3029        }
3030    }
3031}
3032
3033impl From<TimestampError> for EvalError {
3034    fn from(e: TimestampError) -> EvalError {
3035        match e {
3036            TimestampError::OutOfRange => EvalError::TimestampOutOfRange,
3037        }
3038    }
3039}
3040
3041impl From<InvalidRangeError> for EvalError {
3042    fn from(e: InvalidRangeError) -> EvalError {
3043        EvalError::InvalidRange(e)
3044    }
3045}
3046
3047impl RustType<ProtoEvalError> for EvalError {
3048    fn into_proto(&self) -> ProtoEvalError {
3049        use proto_eval_error::Kind::*;
3050        use proto_eval_error::*;
3051        let kind = match self {
3052            EvalError::CharacterNotValidForEncoding(v) => CharacterNotValidForEncoding(*v),
3053            EvalError::CharacterTooLargeForEncoding(v) => CharacterTooLargeForEncoding(*v),
3054            EvalError::DateBinOutOfRange(v) => DateBinOutOfRange(v.into_proto()),
3055            EvalError::DivisionByZero => DivisionByZero(()),
3056            EvalError::Unsupported {
3057                feature,
3058                discussion_no,
3059            } => Unsupported(ProtoUnsupported {
3060                feature: feature.into_proto(),
3061                discussion_no: discussion_no.into_proto(),
3062            }),
3063            EvalError::FloatOverflow => FloatOverflow(()),
3064            EvalError::FloatUnderflow => FloatUnderflow(()),
3065            EvalError::NumericFieldOverflow => NumericFieldOverflow(()),
3066            EvalError::Float32OutOfRange(val) => Float32OutOfRange(ProtoValueOutOfRange {
3067                value: val.to_string(),
3068            }),
3069            EvalError::Float64OutOfRange(val) => Float64OutOfRange(ProtoValueOutOfRange {
3070                value: val.to_string(),
3071            }),
3072            EvalError::Int16OutOfRange(val) => Int16OutOfRange(ProtoValueOutOfRange {
3073                value: val.to_string(),
3074            }),
3075            EvalError::Int32OutOfRange(val) => Int32OutOfRange(ProtoValueOutOfRange {
3076                value: val.to_string(),
3077            }),
3078            EvalError::Int64OutOfRange(val) => Int64OutOfRange(ProtoValueOutOfRange {
3079                value: val.to_string(),
3080            }),
3081            EvalError::UInt16OutOfRange(val) => Uint16OutOfRange(ProtoValueOutOfRange {
3082                value: val.to_string(),
3083            }),
3084            EvalError::UInt32OutOfRange(val) => Uint32OutOfRange(ProtoValueOutOfRange {
3085                value: val.to_string(),
3086            }),
3087            EvalError::UInt64OutOfRange(val) => Uint64OutOfRange(ProtoValueOutOfRange {
3088                value: val.to_string(),
3089            }),
3090            EvalError::MzTimestampOutOfRange(val) => MzTimestampOutOfRange(ProtoValueOutOfRange {
3091                value: val.to_string(),
3092            }),
3093            EvalError::MzTimestampStepOverflow => MzTimestampStepOverflow(()),
3094            EvalError::OidOutOfRange(val) => OidOutOfRange(ProtoValueOutOfRange {
3095                value: val.to_string(),
3096            }),
3097            EvalError::IntervalOutOfRange(val) => IntervalOutOfRange(ProtoValueOutOfRange {
3098                value: val.to_string(),
3099            }),
3100            EvalError::TimestampCannotBeNan => TimestampCannotBeNan(()),
3101            EvalError::TimestampOutOfRange => TimestampOutOfRange(()),
3102            EvalError::DateOutOfRange => DateOutOfRange(()),
3103            EvalError::CharOutOfRange => CharOutOfRange(()),
3104            EvalError::IndexOutOfRange {
3105                provided,
3106                valid_end,
3107            } => IndexOutOfRange(ProtoIndexOutOfRange {
3108                provided: *provided,
3109                valid_end: *valid_end,
3110            }),
3111            EvalError::InvalidBase64Equals => InvalidBase64Equals(()),
3112            EvalError::InvalidBase64Symbol(sym) => InvalidBase64Symbol(sym.into_proto()),
3113            EvalError::InvalidBase64EndSequence => InvalidBase64EndSequence(()),
3114            EvalError::InvalidTimezone(tz) => InvalidTimezone(tz.into_proto()),
3115            EvalError::InvalidTimezoneInterval => InvalidTimezoneInterval(()),
3116            EvalError::InvalidTimezoneConversion => InvalidTimezoneConversion(()),
3117            EvalError::InvalidLayer { max_layer, val } => InvalidLayer(ProtoInvalidLayer {
3118                max_layer: max_layer.into_proto(),
3119                val: *val,
3120            }),
3121            EvalError::InvalidArray(error) => InvalidArray(error.into_proto()),
3122            EvalError::InvalidEncodingName(v) => InvalidEncodingName(v.into_proto()),
3123            EvalError::InvalidHashAlgorithm(v) => InvalidHashAlgorithm(v.into_proto()),
3124            EvalError::InvalidByteSequence {
3125                byte_sequence,
3126                encoding_name,
3127            } => InvalidByteSequence(ProtoInvalidByteSequence {
3128                byte_sequence: byte_sequence.into_proto(),
3129                encoding_name: encoding_name.into_proto(),
3130            }),
3131            EvalError::InvalidJsonbCast { from, to } => InvalidJsonbCast(ProtoInvalidJsonbCast {
3132                from: from.into_proto(),
3133                to: to.into_proto(),
3134            }),
3135            EvalError::InvalidRegex(v) => InvalidRegex(v.into_proto()),
3136            EvalError::InvalidRegexFlag(v) => InvalidRegexFlag(v.into_proto()),
3137            EvalError::InvalidParameterValue(v) => InvalidParameterValue(v.into_proto()),
3138            EvalError::InvalidDatePart(part) => InvalidDatePart(part.into_proto()),
3139            EvalError::KeyCannotBeNull => KeyCannotBeNull(()),
3140            EvalError::NegSqrt => NegSqrt(()),
3141            EvalError::NegLimit => NegLimit(()),
3142            EvalError::NullCharacterNotPermitted => NullCharacterNotPermitted(()),
3143            EvalError::UnknownUnits(v) => UnknownUnits(v.into_proto()),
3144            EvalError::UnsupportedUnits(units, typ) => UnsupportedUnits(ProtoUnsupportedUnits {
3145                units: units.into_proto(),
3146                typ: typ.into_proto(),
3147            }),
3148            EvalError::UnterminatedLikeEscapeSequence => UnterminatedLikeEscapeSequence(()),
3149            EvalError::Parse(error) => Parse(error.into_proto()),
3150            EvalError::PrettyError(error) => PrettyError(error.into_proto()),
3151            EvalError::RedactError(error) => RedactError(error.into_proto()),
3152            EvalError::ParseHex(error) => ParseHex(error.into_proto()),
3153            EvalError::Internal(v) => Internal(v.into_proto()),
3154            EvalError::InfinityOutOfDomain(v) => InfinityOutOfDomain(v.into_proto()),
3155            EvalError::NegativeOutOfDomain(v) => NegativeOutOfDomain(v.into_proto()),
3156            EvalError::ZeroOutOfDomain(v) => ZeroOutOfDomain(v.into_proto()),
3157            EvalError::OutOfDomain(lower, upper, id) => OutOfDomain(ProtoOutOfDomain {
3158                lower: Some(lower.into_proto()),
3159                upper: Some(upper.into_proto()),
3160                id: id.into_proto(),
3161            }),
3162            EvalError::ComplexOutOfRange(v) => ComplexOutOfRange(v.into_proto()),
3163            EvalError::MultipleRowsFromSubquery => MultipleRowsFromSubquery(()),
3164            EvalError::NegativeRowsFromSubquery => NegativeRowsFromSubquery(()),
3165            EvalError::Undefined(v) => Undefined(v.into_proto()),
3166            EvalError::LikePatternTooLong => LikePatternTooLong(()),
3167            EvalError::LikeEscapeTooLong => LikeEscapeTooLong(()),
3168            EvalError::StringValueTooLong {
3169                target_type,
3170                length,
3171            } => StringValueTooLong(ProtoStringValueTooLong {
3172                target_type: target_type.into_proto(),
3173                length: length.into_proto(),
3174            }),
3175            EvalError::MultidimensionalArrayRemovalNotSupported => {
3176                MultidimensionalArrayRemovalNotSupported(())
3177            }
3178            EvalError::IncompatibleArrayDimensions { dims } => {
3179                IncompatibleArrayDimensions(ProtoIncompatibleArrayDimensions {
3180                    dims: dims.into_proto(),
3181                })
3182            }
3183            EvalError::TypeFromOid(v) => TypeFromOid(v.into_proto()),
3184            EvalError::InvalidRange(error) => InvalidRange(error.into_proto()),
3185            EvalError::InvalidRoleId(v) => InvalidRoleId(v.into_proto()),
3186            EvalError::InvalidPrivileges(v) => InvalidPrivileges(v.into_proto()),
3187            EvalError::InvalidCatalogJson(v) => InvalidCatalogJson(v.into_proto()),
3188            EvalError::LetRecLimitExceeded(v) => WmrRecursionLimitExceeded(v.into_proto()),
3189            EvalError::MultiDimensionalArraySearch => MultiDimensionalArraySearch(()),
3190            EvalError::MustNotBeNull(v) => MustNotBeNull(v.into_proto()),
3191            EvalError::InvalidIdentifier { ident, detail } => {
3192                InvalidIdentifier(ProtoInvalidIdentifier {
3193                    ident: ident.into_proto(),
3194                    detail: detail.into_proto(),
3195                })
3196            }
3197            EvalError::ArrayFillWrongArraySubscripts => ArrayFillWrongArraySubscripts(()),
3198            EvalError::MaxArraySizeExceeded(max_size) => {
3199                MaxArraySizeExceeded(u64::cast_from(*max_size))
3200            }
3201            EvalError::DateDiffOverflow { unit, a, b } => DateDiffOverflow(ProtoDateDiffOverflow {
3202                unit: unit.into_proto(),
3203                a: a.into_proto(),
3204                b: b.into_proto(),
3205            }),
3206            EvalError::IfNullError(s) => IfNullError(s.into_proto()),
3207            EvalError::LengthTooLarge => LengthTooLarge(()),
3208            EvalError::AclArrayNullElement => AclArrayNullElement(()),
3209            EvalError::MzAclArrayNullElement => MzAclArrayNullElement(()),
3210            EvalError::InvalidIanaTimezoneId(s) => InvalidIanaTimezoneId(s.into_proto()),
3211        };
3212        ProtoEvalError { kind: Some(kind) }
3213    }
3214
3215    fn from_proto(proto: ProtoEvalError) -> Result<Self, TryFromProtoError> {
3216        use proto_eval_error::Kind::*;
3217        match proto.kind {
3218            Some(kind) => match kind {
3219                CharacterNotValidForEncoding(v) => Ok(EvalError::CharacterNotValidForEncoding(v)),
3220                CharacterTooLargeForEncoding(v) => Ok(EvalError::CharacterTooLargeForEncoding(v)),
3221                DateBinOutOfRange(v) => Ok(EvalError::DateBinOutOfRange(v.into())),
3222                DivisionByZero(()) => Ok(EvalError::DivisionByZero),
3223                Unsupported(v) => Ok(EvalError::Unsupported {
3224                    feature: v.feature.into(),
3225                    discussion_no: v.discussion_no.into_rust()?,
3226                }),
3227                FloatOverflow(()) => Ok(EvalError::FloatOverflow),
3228                FloatUnderflow(()) => Ok(EvalError::FloatUnderflow),
3229                NumericFieldOverflow(()) => Ok(EvalError::NumericFieldOverflow),
3230                Float32OutOfRange(val) => Ok(EvalError::Float32OutOfRange(val.value.into())),
3231                Float64OutOfRange(val) => Ok(EvalError::Float64OutOfRange(val.value.into())),
3232                Int16OutOfRange(val) => Ok(EvalError::Int16OutOfRange(val.value.into())),
3233                Int32OutOfRange(val) => Ok(EvalError::Int32OutOfRange(val.value.into())),
3234                Int64OutOfRange(val) => Ok(EvalError::Int64OutOfRange(val.value.into())),
3235                Uint16OutOfRange(val) => Ok(EvalError::UInt16OutOfRange(val.value.into())),
3236                Uint32OutOfRange(val) => Ok(EvalError::UInt32OutOfRange(val.value.into())),
3237                Uint64OutOfRange(val) => Ok(EvalError::UInt64OutOfRange(val.value.into())),
3238                MzTimestampOutOfRange(val) => {
3239                    Ok(EvalError::MzTimestampOutOfRange(val.value.into()))
3240                }
3241                MzTimestampStepOverflow(()) => Ok(EvalError::MzTimestampStepOverflow),
3242                OidOutOfRange(val) => Ok(EvalError::OidOutOfRange(val.value.into())),
3243                IntervalOutOfRange(val) => Ok(EvalError::IntervalOutOfRange(val.value.into())),
3244                TimestampCannotBeNan(()) => Ok(EvalError::TimestampCannotBeNan),
3245                TimestampOutOfRange(()) => Ok(EvalError::TimestampOutOfRange),
3246                DateOutOfRange(()) => Ok(EvalError::DateOutOfRange),
3247                CharOutOfRange(()) => Ok(EvalError::CharOutOfRange),
3248                IndexOutOfRange(v) => Ok(EvalError::IndexOutOfRange {
3249                    provided: v.provided,
3250                    valid_end: v.valid_end,
3251                }),
3252                InvalidBase64Equals(()) => Ok(EvalError::InvalidBase64Equals),
3253                InvalidBase64Symbol(v) => char::from_proto(v).map(EvalError::InvalidBase64Symbol),
3254                InvalidBase64EndSequence(()) => Ok(EvalError::InvalidBase64EndSequence),
3255                InvalidTimezone(v) => Ok(EvalError::InvalidTimezone(v.into())),
3256                InvalidTimezoneInterval(()) => Ok(EvalError::InvalidTimezoneInterval),
3257                InvalidTimezoneConversion(()) => Ok(EvalError::InvalidTimezoneConversion),
3258                InvalidLayer(v) => Ok(EvalError::InvalidLayer {
3259                    max_layer: usize::from_proto(v.max_layer)?,
3260                    val: v.val,
3261                }),
3262                InvalidArray(error) => Ok(EvalError::InvalidArray(error.into_rust()?)),
3263                InvalidEncodingName(v) => Ok(EvalError::InvalidEncodingName(v.into())),
3264                InvalidHashAlgorithm(v) => Ok(EvalError::InvalidHashAlgorithm(v.into())),
3265                InvalidByteSequence(v) => Ok(EvalError::InvalidByteSequence {
3266                    byte_sequence: v.byte_sequence.into(),
3267                    encoding_name: v.encoding_name.into(),
3268                }),
3269                InvalidJsonbCast(v) => Ok(EvalError::InvalidJsonbCast {
3270                    from: v.from.into(),
3271                    to: v.to.into(),
3272                }),
3273                InvalidRegex(v) => Ok(EvalError::InvalidRegex(v.into())),
3274                InvalidRegexFlag(v) => Ok(EvalError::InvalidRegexFlag(char::from_proto(v)?)),
3275                InvalidParameterValue(v) => Ok(EvalError::InvalidParameterValue(v.into())),
3276                InvalidDatePart(part) => Ok(EvalError::InvalidDatePart(part.into())),
3277                KeyCannotBeNull(()) => Ok(EvalError::KeyCannotBeNull),
3278                NegSqrt(()) => Ok(EvalError::NegSqrt),
3279                NegLimit(()) => Ok(EvalError::NegLimit),
3280                NullCharacterNotPermitted(()) => Ok(EvalError::NullCharacterNotPermitted),
3281                UnknownUnits(v) => Ok(EvalError::UnknownUnits(v.into())),
3282                UnsupportedUnits(v) => {
3283                    Ok(EvalError::UnsupportedUnits(v.units.into(), v.typ.into()))
3284                }
3285                UnterminatedLikeEscapeSequence(()) => Ok(EvalError::UnterminatedLikeEscapeSequence),
3286                Parse(error) => Ok(EvalError::Parse(error.into_rust()?)),
3287                ParseHex(error) => Ok(EvalError::ParseHex(error.into_rust()?)),
3288                Internal(v) => Ok(EvalError::Internal(v.into())),
3289                InfinityOutOfDomain(v) => Ok(EvalError::InfinityOutOfDomain(v.into())),
3290                NegativeOutOfDomain(v) => Ok(EvalError::NegativeOutOfDomain(v.into())),
3291                ZeroOutOfDomain(v) => Ok(EvalError::ZeroOutOfDomain(v.into())),
3292                OutOfDomain(v) => Ok(EvalError::OutOfDomain(
3293                    v.lower.into_rust_if_some("ProtoDomainLimit::lower")?,
3294                    v.upper.into_rust_if_some("ProtoDomainLimit::upper")?,
3295                    v.id.into(),
3296                )),
3297                ComplexOutOfRange(v) => Ok(EvalError::ComplexOutOfRange(v.into())),
3298                MultipleRowsFromSubquery(()) => Ok(EvalError::MultipleRowsFromSubquery),
3299                NegativeRowsFromSubquery(()) => Ok(EvalError::NegativeRowsFromSubquery),
3300                Undefined(v) => Ok(EvalError::Undefined(v.into())),
3301                LikePatternTooLong(()) => Ok(EvalError::LikePatternTooLong),
3302                LikeEscapeTooLong(()) => Ok(EvalError::LikeEscapeTooLong),
3303                StringValueTooLong(v) => Ok(EvalError::StringValueTooLong {
3304                    target_type: v.target_type.into(),
3305                    length: usize::from_proto(v.length)?,
3306                }),
3307                MultidimensionalArrayRemovalNotSupported(()) => {
3308                    Ok(EvalError::MultidimensionalArrayRemovalNotSupported)
3309                }
3310                IncompatibleArrayDimensions(v) => Ok(EvalError::IncompatibleArrayDimensions {
3311                    dims: v.dims.into_rust()?,
3312                }),
3313                TypeFromOid(v) => Ok(EvalError::TypeFromOid(v.into())),
3314                InvalidRange(e) => Ok(EvalError::InvalidRange(e.into_rust()?)),
3315                InvalidRoleId(v) => Ok(EvalError::InvalidRoleId(v.into())),
3316                InvalidPrivileges(v) => Ok(EvalError::InvalidPrivileges(v.into())),
3317                InvalidCatalogJson(v) => Ok(EvalError::InvalidCatalogJson(v.into())),
3318                WmrRecursionLimitExceeded(v) => Ok(EvalError::LetRecLimitExceeded(v.into())),
3319                MultiDimensionalArraySearch(()) => Ok(EvalError::MultiDimensionalArraySearch),
3320                MustNotBeNull(v) => Ok(EvalError::MustNotBeNull(v.into())),
3321                InvalidIdentifier(v) => Ok(EvalError::InvalidIdentifier {
3322                    ident: v.ident.into(),
3323                    detail: v.detail.into_rust()?,
3324                }),
3325                ArrayFillWrongArraySubscripts(()) => Ok(EvalError::ArrayFillWrongArraySubscripts),
3326                MaxArraySizeExceeded(max_size) => {
3327                    Ok(EvalError::MaxArraySizeExceeded(usize::cast_from(max_size)))
3328                }
3329                DateDiffOverflow(v) => Ok(EvalError::DateDiffOverflow {
3330                    unit: v.unit.into(),
3331                    a: v.a.into(),
3332                    b: v.b.into(),
3333                }),
3334                IfNullError(v) => Ok(EvalError::IfNullError(v.into())),
3335                LengthTooLarge(()) => Ok(EvalError::LengthTooLarge),
3336                AclArrayNullElement(()) => Ok(EvalError::AclArrayNullElement),
3337                MzAclArrayNullElement(()) => Ok(EvalError::MzAclArrayNullElement),
3338                InvalidIanaTimezoneId(s) => Ok(EvalError::InvalidIanaTimezoneId(s.into())),
3339                PrettyError(s) => Ok(EvalError::PrettyError(s.into())),
3340                RedactError(s) => Ok(EvalError::RedactError(s.into())),
3341            },
3342            None => Err(TryFromProtoError::missing_field("ProtoEvalError::kind")),
3343        }
3344    }
3345}
3346
3347impl RustType<ProtoDims> for (usize, usize) {
3348    fn into_proto(&self) -> ProtoDims {
3349        ProtoDims {
3350            f0: self.0.into_proto(),
3351            f1: self.1.into_proto(),
3352        }
3353    }
3354
3355    fn from_proto(proto: ProtoDims) -> Result<Self, TryFromProtoError> {
3356        Ok((proto.f0.into_rust()?, proto.f1.into_rust()?))
3357    }
3358}
3359
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `MirScalarExpr::reduce` simplifies `Coalesce` calls to the
    /// expected expressions, given a three-column relation whose first two
    /// columns are nullable and whose third is not.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_reduce() {
        // Columns 0 and 1 are nullable; column 2 is not.
        let relation_type: Vec<ReprColumnType> = vec![
            ReprScalarType::Int64.nullable(true),
            ReprScalarType::Int64.nullable(true),
            ReprScalarType::Int64.nullable(false),
        ];

        // Small builders that keep the case table below readable.
        let int64_typ = ReprScalarType::Int64;
        let col = MirScalarExpr::column;
        let lit = |value| MirScalarExpr::literal_ok(Datum::Int64(value), int64_typ.clone());
        let null = || MirScalarExpr::literal_null(int64_typ.clone());
        let err = |error| MirScalarExpr::literal(Err(error), int64_typ.clone());
        let coalesce = |args: Vec<MirScalarExpr>| MirScalarExpr::call_variadic(Coalesce, args);

        // Each entry is `(input, expected result of reducing input)`.
        let test_cases = vec![
            // A single argument is returned as-is.
            (coalesce(vec![lit(1)]), lit(1)),
            // The first literal wins over later literals.
            (coalesce(vec![lit(1), lit(2)]), lit(1)),
            // Null literals on either side are discarded.
            (coalesce(vec![null(), lit(2), null()]), lit(2)),
            // Null literals are removed, and nothing after the first
            // non-null literal is kept.
            (
                coalesce(vec![null(), col(0), null(), col(1), lit(2), lit(3)]),
                coalesce(vec![col(0), col(1), lit(2)]),
            ),
            // Arguments after the non-nullable column 2 are dropped.
            (
                coalesce(vec![col(0), col(2), col(1)]),
                coalesce(vec![col(0), col(2)]),
            ),
            // An error literal following a non-null literal is dropped.
            (coalesce(vec![lit(1), err(EvalError::DivisionByZero)]), lit(1)),
            // With only nulls and errors, the first error is the result.
            (
                coalesce(vec![
                    null(),
                    err(EvalError::DivisionByZero),
                    err(EvalError::NumericFieldOverflow),
                ]),
                err(EvalError::DivisionByZero),
            ),
        ];

        for (input, expected) in test_cases {
            // Reduce a copy so the untouched input can be shown on failure.
            let mut reduced = input.clone();
            reduced.reduce(&relation_type);
            assert!(
                reduced == expected,
                "input: {}\nactual: {}\nexpected: {}",
                input,
                reduced,
                expected
            );
        }
    }
}