Skip to main content

mz_expr/
scalar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::ops::BitOrAssign;
12use std::sync::Arc;
13use std::{fmt, mem};
14
15use itertools::Itertools;
16use mz_lowertest::MzReflect;
17use mz_ore::cast::CastFrom;
18use mz_ore::collections::CollectionExt;
19use mz_ore::iter::IteratorExt;
20use mz_ore::stack::RecursionLimitError;
21use mz_ore::str::StrExt;
22use mz_ore::treat_as_equal::TreatAsEqual;
23use mz_ore::vec::swap_remove_multiple;
24use mz_pgrepr::TypeFromOidError;
25use mz_pgtz::timezone::TimezoneSpec;
26use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
27use mz_repr::adt::array::InvalidArrayError;
28use mz_repr::adt::date::DateError;
29use mz_repr::adt::datetime::DateTimeUnits;
30use mz_repr::adt::range::InvalidRangeError;
31use mz_repr::adt::regex::{Regex, RegexCompilationError};
32use mz_repr::adt::timestamp::TimestampError;
33use mz_repr::strconv::{ParseError, ParseHexError};
34use mz_repr::{Datum, ReprColumnType, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest::prelude::*;
36use proptest_derive::Arbitrary;
37use serde::{Deserialize, Serialize};
38
39use crate::scalar::func::format::DateTimeFormat;
40use crate::scalar::func::{
41    BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, parse_timezone,
42    regexp_replace_parse_flags,
43};
44use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims;
45use crate::visit::{Visit, VisitChildren};
46
47pub mod func;
48pub mod like_pattern;
49
50include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs"));
51
/// A scalar expression tree: column references, literals, function calls of
/// various arities, and conditionals.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, MzReflect)]
pub enum MirScalarExpr {
    /// A column of the input row, identified by its index.
    ///
    /// The second field carries optional name information for the column.
    // NOTE(review): the `TreatAsEqual` wrapper presumably keeps the name from
    // influencing comparisons and hashing -- confirm against `mz_ore`.
    Column(usize, TreatAsEqual<Option<Arc<str>>>),
    /// A literal value, or the error that evaluating it produced, together
    /// with the literal's column type.
    /// (Stored as a row, because we can't own a Datum)
    Literal(Result<Row, EvalError>, SqlColumnType),
    /// A call to an unmaterializable function.
    ///
    /// These functions cannot be evaluated by `MirScalarExpr::eval`. They must
    /// be transformed away by a higher layer.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function call that takes one expression as an argument.
    CallUnary {
        func: UnaryFunc,
        expr: Box<MirScalarExpr>,
    },
    /// A function call that takes two expressions as arguments.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<MirScalarExpr>,
        expr2: Box<MirScalarExpr>,
    },
    /// A function call that takes an arbitrary number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<MirScalarExpr>,
    },
    /// Conditionally evaluated expressions.
    ///
    /// It is important that `then` and `els` only be evaluated if
    /// `cond` is true or not, respectively. This is the only way
    /// users can guard execution (other logical operators do not
    /// short-circuit) and we need to preserve that.
    If {
        cond: Box<MirScalarExpr>,
        then: Box<MirScalarExpr>,
        els: Box<MirScalarExpr>,
    },
}
92
93// We need a custom Debug because we don't want to show `None` for name information.
94// Sadly, the `derivative` crate doesn't support this use case.
95impl std::fmt::Debug for MirScalarExpr {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        match self {
98            MirScalarExpr::Column(i, TreatAsEqual(Some(name))) => {
99                write!(f, "Column({i}, {name:?})")
100            }
101            MirScalarExpr::Column(i, TreatAsEqual(None)) => write!(f, "Column({i})"),
102            MirScalarExpr::Literal(lit, typ) => write!(f, "Literal({lit:?}, {typ:?})"),
103            MirScalarExpr::CallUnmaterializable(func) => {
104                write!(f, "CallUnmaterializable({func:?})")
105            }
106            MirScalarExpr::CallUnary { func, expr } => {
107                write!(f, "CallUnary({func:?}, {expr:?})")
108            }
109            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
110                write!(f, "CallBinary({func:?}, {expr1:?}, {expr2:?})")
111            }
112            MirScalarExpr::CallVariadic { func, exprs } => {
113                write!(f, "CallVariadic({func:?}, {exprs:?})")
114            }
115            MirScalarExpr::If { cond, then, els } => {
116                write!(f, "If({cond:?}, {then:?}, {els:?})")
117            }
118        }
119    }
120}
121
122impl MirScalarExpr {
123    pub fn columns(is: &[usize]) -> Vec<MirScalarExpr> {
124        is.iter().map(|i| MirScalarExpr::column(*i)).collect()
125    }
126
127    pub fn column(column: usize) -> Self {
128        MirScalarExpr::Column(column, TreatAsEqual(None))
129    }
130
131    pub fn named_column(column: usize, name: Arc<str>) -> Self {
132        MirScalarExpr::Column(column, TreatAsEqual(Some(name)))
133    }
134
135    pub fn literal(res: Result<Datum, EvalError>, typ: SqlScalarType) -> Self {
136        let typ = typ.nullable(matches!(res, Ok(Datum::Null)));
137        let row = res.map(|datum| Row::pack_slice(&[datum]));
138        MirScalarExpr::Literal(row, typ)
139    }
140
141    pub fn literal_ok(datum: Datum, typ: SqlScalarType) -> Self {
142        MirScalarExpr::literal(Ok(datum), typ)
143    }
144
145    pub fn literal_null(typ: SqlScalarType) -> Self {
146        MirScalarExpr::literal_ok(Datum::Null, typ)
147    }
148
149    pub fn literal_false() -> Self {
150        MirScalarExpr::literal_ok(Datum::False, SqlScalarType::Bool)
151    }
152
153    pub fn literal_true() -> Self {
154        MirScalarExpr::literal_ok(Datum::True, SqlScalarType::Bool)
155    }
156
157    pub fn call_unary<U: Into<UnaryFunc>>(self, func: U) -> Self {
158        MirScalarExpr::CallUnary {
159            func: func.into(),
160            expr: Box::new(self),
161        }
162    }
163
164    pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
165        MirScalarExpr::CallBinary {
166            func: func.into(),
167            expr1: Box::new(self),
168            expr2: Box::new(other),
169        }
170    }
171
172    pub fn if_then_else(self, t: Self, f: Self) -> Self {
173        MirScalarExpr::If {
174            cond: Box::new(self),
175            then: Box::new(t),
176            els: Box::new(f),
177        }
178    }
179
180    pub fn or(self, other: Self) -> Self {
181        MirScalarExpr::CallVariadic {
182            func: VariadicFunc::Or,
183            exprs: vec![self, other],
184        }
185    }
186
187    pub fn and(self, other: Self) -> Self {
188        MirScalarExpr::CallVariadic {
189            func: VariadicFunc::And,
190            exprs: vec![self, other],
191        }
192    }
193
194    pub fn not(self) -> Self {
195        self.call_unary(UnaryFunc::Not(func::Not))
196    }
197
198    pub fn call_is_null(self) -> Self {
199        self.call_unary(UnaryFunc::IsNull(func::IsNull))
200    }
201
202    /// Match AND or OR on self and get the args. If no match, then interpret self as if it were
203    /// wrapped in a 1-arg AND/OR.
204    pub fn and_or_args(&self, func_to_match: VariadicFunc) -> Vec<MirScalarExpr> {
205        assert!(func_to_match == VariadicFunc::Or || func_to_match == VariadicFunc::And);
206        match self {
207            MirScalarExpr::CallVariadic { func, exprs } if *func == func_to_match => exprs.clone(),
208            _ => vec![self.clone()],
209        }
210    }
211
212    /// Try to match a literal equality involving the given expression on one side.
213    /// Return the (non-null) literal and a bool that indicates whether an inversion was needed.
214    ///
215    /// More specifically:
216    /// If `self` is an equality with a `null` literal on any side, then the match fails!
217    /// Otherwise: for a given `expr`, if `self` is `<expr> = <literal>` or `<literal> = <expr>`
218    /// then return `Some((<literal>, false))`. In addition to just trying to match `<expr>` as it
219    /// is, we also try to remove an invertible function call (such as a cast). If the match
220    /// succeeds with the inversion, then return `Some((<inverted-literal>, true))`. For more
221    /// details on the inversion, see `invert_casts_on_expr_eq_literal_inner`.
222    pub fn expr_eq_literal(&self, expr: &MirScalarExpr) -> Option<(Row, bool)> {
223        if let MirScalarExpr::CallBinary {
224            func: BinaryFunc::Eq(_),
225            expr1,
226            expr2,
227        } = self
228        {
229            if expr1.is_literal_null() || expr2.is_literal_null() {
230                return None;
231            }
232            if let Some(Ok(lit)) = expr1.as_literal_owned() {
233                return Self::expr_eq_literal_inner(expr, lit, expr1, expr2);
234            }
235            if let Some(Ok(lit)) = expr2.as_literal_owned() {
236                return Self::expr_eq_literal_inner(expr, lit, expr2, expr1);
237            }
238        }
239        None
240    }
241
242    fn expr_eq_literal_inner(
243        expr_to_match: &MirScalarExpr,
244        literal: Row,
245        literal_expr: &MirScalarExpr,
246        other_side: &MirScalarExpr,
247    ) -> Option<(Row, bool)> {
248        if other_side == expr_to_match {
249            return Some((literal, false));
250        } else {
251            // expr didn't exactly match. See if we can match it by inverse-casting.
252            let (cast_removed, inv_cast_lit) =
253                Self::invert_casts_on_expr_eq_literal_inner(other_side, literal_expr);
254            if &cast_removed == expr_to_match {
255                if let Some(Ok(inv_cast_lit_row)) = inv_cast_lit.as_literal_owned() {
256                    return Some((inv_cast_lit_row, true));
257                }
258            }
259        }
260        None
261    }
262
263    /// If `self` is `<expr> = <literal>` or `<literal> = <expr>` then
264    /// return `<expr>`. It also tries to remove a cast (or other invertible function call) from
265    /// `<expr>` before returning it, see `invert_casts_on_expr_eq_literal_inner`.
266    pub fn any_expr_eq_literal(&self) -> Option<MirScalarExpr> {
267        if let MirScalarExpr::CallBinary {
268            func: BinaryFunc::Eq(_),
269            expr1,
270            expr2,
271        } = self
272        {
273            if expr1.is_literal() {
274                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
275                return Some(expr);
276            }
277            if expr2.is_literal() {
278                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
279                return Some(expr);
280            }
281        }
282        None
283    }
284
285    /// If the given `MirScalarExpr` is a literal equality where one side is an invertible function
286    /// call, then calls the inverse function on both sides of the equality and returns the modified
287    /// version of the given `MirScalarExpr`. Otherwise, it returns the original expression.
288    /// For more details, see `invert_casts_on_expr_eq_literal_inner`.
289    pub fn invert_casts_on_expr_eq_literal(&self) -> MirScalarExpr {
290        if let MirScalarExpr::CallBinary {
291            func: BinaryFunc::Eq(_),
292            expr1,
293            expr2,
294        } = self
295        {
296            if expr1.is_literal() {
297                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
298                return literal.call_binary(expr, func::Eq);
299            }
300            if expr2.is_literal() {
301                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
302                return literal.call_binary(expr, func::Eq);
303            }
304            // Note: The above return statements should be consistent in whether they put the
305            // literal in expr1 or expr2, for the deduplication in CanonicalizeMfp to work.
306        }
307        self.clone()
308    }
309
310    /// Given an `<expr>` and a `<literal>` that were taken out from `<expr> = <literal>` or
311    /// `<literal> = <expr>`, it tries to simplify the equality by applying the inverse function of
312    /// the outermost function call of `<expr>` (if exists):
313    ///
314    /// `<literal> = func(<inner_expr>)`, where `func` is invertible
315    ///  -->
316    /// `<func^-1(literal)> = <inner_expr>`
317    /// if `func^-1(literal)` doesn't error out, and both `func` and `func^-1` preserve uniqueness.
318    ///
319    /// The return value is the `<inner_expr>` and the literal value that we get by applying the
320    /// inverse function.
321    fn invert_casts_on_expr_eq_literal_inner(
322        expr: &MirScalarExpr,
323        literal: &MirScalarExpr,
324    ) -> (MirScalarExpr, MirScalarExpr) {
325        assert!(matches!(literal, MirScalarExpr::Literal(..)));
326
327        let temp_storage = &RowArena::new();
328        let eval = |e: &MirScalarExpr| {
329            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
330        };
331
332        if let MirScalarExpr::CallUnary {
333            func,
334            expr: inner_expr,
335        } = expr
336        {
337            if let Some(inverse_func) = func.inverse() {
338                // We don't want to remove a function call that doesn't preserve uniqueness, e.g.,
339                // if `f` is a float, we don't want to inverse-cast `f::INT = 0`, because the
340                // inserted int-to-float cast wouldn't be able to invert the rounding.
341                // Also, we don't want to insert a function call that doesn't preserve
342                // uniqueness. E.g., if `a` has an integer type, we don't want to do
343                // a surprise rounding for `WHERE a = 3.14`.
344                if func.preserves_uniqueness() && inverse_func.preserves_uniqueness() {
345                    let lit_inv = eval(&MirScalarExpr::CallUnary {
346                        func: inverse_func,
347                        expr: Box::new(literal.clone()),
348                    });
349                    // The evaluation can error out, e.g., when casting a too large int32 to int16.
350                    // This case is handled by `impossible_literal_equality_because_types`.
351                    if !lit_inv.is_literal_err() {
352                        return (*inner_expr.clone(), lit_inv);
353                    }
354                }
355            }
356        }
357        (expr.clone(), literal.clone())
358    }
359
360    /// Tries to remove a cast (or other invertible function) in the same way as
361    /// `invert_casts_on_expr_eq_literal`, but if calling the inverse function fails on the literal,
362    /// then it deems the equality to be impossible. For example if `a` is a smallint column, then
363    /// it catches `a::integer = 1000000` to be an always false predicate (where the `::integer`
364    /// could have been inserted implicitly).
365    pub fn impossible_literal_equality_because_types(&self) -> bool {
366        if let MirScalarExpr::CallBinary {
367            func: BinaryFunc::Eq(_),
368            expr1,
369            expr2,
370        } = self
371        {
372            if expr1.is_literal() {
373                return Self::impossible_literal_equality_because_types_inner(expr1, expr2);
374            }
375            if expr2.is_literal() {
376                return Self::impossible_literal_equality_because_types_inner(expr2, expr1);
377            }
378        }
379        false
380    }
381
382    fn impossible_literal_equality_because_types_inner(
383        literal: &MirScalarExpr,
384        other_side: &MirScalarExpr,
385    ) -> bool {
386        assert!(matches!(literal, MirScalarExpr::Literal(..)));
387
388        let temp_storage = &RowArena::new();
389        let eval = |e: &MirScalarExpr| {
390            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
391        };
392
393        if let MirScalarExpr::CallUnary { func, .. } = other_side {
394            if let Some(inverse_func) = func.inverse() {
395                if inverse_func.preserves_uniqueness()
396                    && eval(&MirScalarExpr::CallUnary {
397                        func: inverse_func,
398                        expr: Box::new(literal.clone()),
399                    })
400                    .is_literal_err()
401                {
402                    return true;
403                }
404            }
405        }
406
407        false
408    }
409
410    /// Determines if `self` is
411    /// `<expr> < <literal>` or
412    /// `<expr> > <literal>` or
413    /// `<literal> < <expr>` or
414    /// `<literal> > <expr>` or
415    /// `<expr> <= <literal>` or
416    /// `<expr> >= <literal>` or
417    /// `<literal> <= <expr>` or
418    /// `<literal> >= <expr>`.
419    pub fn any_expr_ineq_literal(&self) -> bool {
420        match self {
421            MirScalarExpr::CallBinary {
422                func:
423                    BinaryFunc::Lt(_) | BinaryFunc::Lte(_) | BinaryFunc::Gt(_) | BinaryFunc::Gte(_),
424                expr1,
425                expr2,
426            } => expr1.is_literal() || expr2.is_literal(),
427            _ => false,
428        }
429    }
430
431    /// Rewrites column indices with their value in `permutation`.
432    ///
433    /// This method is applicable even when `permutation` is not a
434    /// strict permutation, and it only needs to have entries for
435    /// each column referenced in `self`.
436    pub fn permute(&mut self, permutation: &[usize]) {
437        self.visit_columns(|c| *c = permutation[*c]);
438    }
439
440    /// Rewrites column indices with their value in `permutation`.
441    ///
442    /// This method is applicable even when `permutation` is not a
443    /// strict permutation, and it only needs to have entries for
444    /// each column referenced in `self`.
445    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
446        self.visit_columns(|c| *c = permutation[c]);
447    }
448
449    /// Visits each column reference and applies `action` to the column.
450    ///
451    /// Useful for remapping columns, or for collecting expression support.
452    pub fn visit_columns<F>(&mut self, mut action: F)
453    where
454        F: FnMut(&mut usize),
455    {
456        self.visit_pre_mut(|e| {
457            if let MirScalarExpr::Column(col, _) = e {
458                action(col);
459            }
460        });
461    }
462
463    pub fn support(&self) -> BTreeSet<usize> {
464        let mut support = BTreeSet::new();
465        self.support_into(&mut support);
466        support
467    }
468
469    pub fn support_into(&self, support: &mut BTreeSet<usize>) {
470        self.visit_pre(|e| {
471            if let MirScalarExpr::Column(i, _) = e {
472                support.insert(*i);
473            }
474        });
475    }
476
477    pub fn take(&mut self) -> Self {
478        mem::replace(self, MirScalarExpr::literal_null(SqlScalarType::String))
479    }
480
481    /// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
482    /// Otherwise, it returns None.
483    pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
484        if let MirScalarExpr::Literal(lit, _column_type) = self {
485            Some(lit.as_ref().map(|row| row.unpack_first()))
486        } else {
487            None
488        }
489    }
490
491    /// Flattens the two failure modes of `as_literal` into one layer of Option: returns the
492    /// literal's Datum only if the expression is a literal, and it's not a literal error.
493    pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
494        self.as_literal().map(|eval_err| eval_err.ok()).flatten()
495    }
496
497    pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
498        if let MirScalarExpr::Literal(lit, _column_type) = self {
499            Some(lit.clone())
500        } else {
501            None
502        }
503    }
504
505    pub fn as_literal_str(&self) -> Option<&str> {
506        match self.as_literal() {
507            Some(Ok(Datum::String(s))) => Some(s),
508            _ => None,
509        }
510    }
511
512    pub fn as_literal_int64(&self) -> Option<i64> {
513        match self.as_literal() {
514            Some(Ok(Datum::Int64(i))) => Some(i),
515            _ => None,
516        }
517    }
518
519    pub fn as_literal_err(&self) -> Option<&EvalError> {
520        self.as_literal().and_then(|lit| lit.err())
521    }
522
523    pub fn is_literal(&self) -> bool {
524        matches!(self, MirScalarExpr::Literal(_, _))
525    }
526
527    pub fn is_literal_true(&self) -> bool {
528        Some(Ok(Datum::True)) == self.as_literal()
529    }
530
531    pub fn is_literal_false(&self) -> bool {
532        Some(Ok(Datum::False)) == self.as_literal()
533    }
534
535    pub fn is_literal_null(&self) -> bool {
536        Some(Ok(Datum::Null)) == self.as_literal()
537    }
538
539    pub fn is_literal_ok(&self) -> bool {
540        matches!(self, MirScalarExpr::Literal(Ok(_), _typ))
541    }
542
543    pub fn is_literal_err(&self) -> bool {
544        matches!(self, MirScalarExpr::Literal(Err(_), _typ))
545    }
546
547    pub fn is_column(&self) -> bool {
548        matches!(self, MirScalarExpr::Column(_col, _name))
549    }
550
551    pub fn is_error_if_null(&self) -> bool {
552        matches!(
553            self,
554            Self::CallVariadic {
555                func: VariadicFunc::ErrorIfNull,
556                ..
557            }
558        )
559    }
560
561    /// If `self` expresses a temporal filter, normalize it to start with `mz_now()` and return
562    /// references.
563    ///
564    /// A temporal filter is an expression of the form `mz_now() <BINOP> <EXPR>`,
565    /// for a restricted set of `BINOP` and `EXPR` that do not themselves contain `mz_now()`.
566    /// Expressions may conform to this once their expressions are swapped.
567    ///
568    /// If the expression is not a temporal filter, it will be unchanged, and the reason for why
569    /// it's not a temporal filter is returned as a string.
570    pub fn as_mut_temporal_filter(&mut self) -> Result<(&BinaryFunc, &mut MirScalarExpr), String> {
571        if !self.contains_temporal() {
572            return Err("Does not involve mz_now()".to_string());
573        }
574        // Supported temporal predicates are exclusively binary operators.
575        if let MirScalarExpr::CallBinary { func, expr1, expr2 } = self {
576            // Attempt to put `LogicalTimestamp` in the first argument position.
577            if !expr1.contains_temporal()
578                && **expr2 == MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
579            {
580                std::mem::swap(expr1, expr2);
581                *func = match func {
582                    BinaryFunc::Eq(_) => func::Eq.into(),
583                    BinaryFunc::Lt(_) => func::Gt.into(),
584                    BinaryFunc::Lte(_) => func::Gte.into(),
585                    BinaryFunc::Gt(_) => func::Lt.into(),
586                    BinaryFunc::Gte(_) => func::Lte.into(),
587                    x => {
588                        return Err(format!("Unsupported binary temporal operation: {:?}", x));
589                    }
590                };
591            }
592
593            // Error if MLT is referenced in an unsupported position.
594            if expr2.contains_temporal()
595                || **expr1 != MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow)
596            {
597                return Err(format!(
598                    "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a mz_timestamp-castable expression. Expression found: {}",
599                    MirScalarExpr::CallBinary {
600                        func: func.clone(),
601                        expr1: expr1.clone(),
602                        expr2: expr2.clone()
603                    },
604                ));
605            }
606
607            Ok((&*func, expr2))
608        } else {
609            Err(format!(
610                "Unsupported temporal predicate. Note: `mz_now()` must be directly compared to a non-temporal expression of mz_timestamp-castable type. Expression found: {}",
611                self,
612            ))
613        }
614    }
615
616    #[deprecated = "Use `might_error` instead"]
617    pub fn contains_error_if_null(&self) -> bool {
618        let mut worklist = vec![self];
619        while let Some(expr) = worklist.pop() {
620            if expr.is_error_if_null() {
621                return true;
622            }
623            worklist.extend(expr.children());
624        }
625        false
626    }
627
628    pub fn contains_err(&self) -> bool {
629        let mut worklist = vec![self];
630        while let Some(expr) = worklist.pop() {
631            if expr.is_literal_err() {
632                return true;
633            }
634            worklist.extend(expr.children());
635        }
636        false
637    }
638
639    /// A very crude approximation for scalar expressions that might produce an
640    /// error.
641    ///
642    /// Currently, this is restricted only to expressions that either contain a
643    /// literal error or a [`VariadicFunc::ErrorIfNull`] call.
644    pub fn might_error(&self) -> bool {
645        let mut worklist = vec![self];
646        while let Some(expr) = worklist.pop() {
647            if expr.is_literal_err() || expr.is_error_if_null() {
648                return true;
649            }
650            worklist.extend(expr.children());
651        }
652        false
653    }
654
655    /// If self is a column, return the column index, otherwise `None`.
656    pub fn as_column(&self) -> Option<usize> {
657        if let MirScalarExpr::Column(c, _) = self {
658            Some(*c)
659        } else {
660            None
661        }
662    }
663
664    /// Reduces a complex expression where possible.
665    ///
666    /// This function uses nullability information present in `column_types`,
667    /// and the result may only continue to be a correct transformation as
668    /// long as this information continues to hold (nullability may not hold
669    /// as expressions migrate around).
670    ///
671    /// (If you'd like to not use nullability information here, then you can
672    /// tweak the nullabilities in `column_types` before passing it to this
673    /// function, see e.g. in `EquivalenceClasses::minimize`.)
674    ///
675    /// Also performs partial canonicalization on the expression.
676    ///
677    /// ```rust
678    /// use mz_expr::MirScalarExpr;
679    /// use mz_repr::{SqlColumnType, Datum, SqlScalarType};
680    ///
681    /// let expr_0 = MirScalarExpr::column(0);
682    /// let expr_t = MirScalarExpr::literal_true();
683    /// let expr_f = MirScalarExpr::literal_false();
684    ///
685    /// let mut test =
686    /// expr_t
687    ///     .clone()
688    ///     .and(expr_f.clone())
689    ///     .if_then_else(expr_0, expr_t.clone());
690    ///
691    /// let input_type = vec![SqlScalarType::Int32.nullable(false)];
692    /// test.reduce(&input_type);
693    /// assert_eq!(test, expr_t);
694    /// ```
695    pub fn reduce(&mut self, column_types: &[SqlColumnType]) {
696        let temp_storage = &RowArena::new();
697        let eval = |e: &MirScalarExpr| {
698            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type)
699        };
700
701        // Simplifications run in a loop until `self` no longer changes.
702        let mut old_self = MirScalarExpr::column(0);
703        while old_self != *self {
704            old_self = self.clone();
705            #[allow(deprecated)]
706            self.visit_mut_pre_post_nolimit(
707                &mut |e| {
708                    match e {
709                        MirScalarExpr::CallUnary { func, expr } => {
710                            if *func == UnaryFunc::IsNull(func::IsNull) {
711                                if !expr.typ(column_types).nullable {
712                                    *e = MirScalarExpr::literal_false();
713                                } else {
714                                    // Try to at least decompose IsNull into a disjunction
715                                    // of simpler IsNull subexpressions.
716                                    if let Some(expr) = expr.decompose_is_null() {
717                                        *e = expr
718                                    }
719                                }
720                            } else if *func == UnaryFunc::Not(func::Not) {
721                                // Push down not expressions
722                                match &mut **expr {
723                                    // Two negates cancel each other out.
724                                    MirScalarExpr::CallUnary {
725                                        expr: inner_expr,
726                                        func: UnaryFunc::Not(func::Not),
727                                    } => *e = inner_expr.take(),
728                                    // Transforms `NOT(a <op> b)` to `a negate(<op>) b`
729                                    // if a negation exists.
730                                    MirScalarExpr::CallBinary { expr1, expr2, func } => {
731                                        if let Some(negated_func) = func.negate() {
732                                            *e = MirScalarExpr::CallBinary {
733                                                expr1: Box::new(expr1.take()),
734                                                expr2: Box::new(expr2.take()),
735                                                func: negated_func,
736                                            }
737                                        }
738                                    }
739                                    MirScalarExpr::CallVariadic { .. } => {
740                                        e.demorgans();
741                                    }
742                                    _ => {}
743                                }
744                            }
745                        }
746                        _ => {}
747                    };
748                    None
749                },
750                &mut |e| match e {
751                    // Evaluate and pull up constants
752                    MirScalarExpr::Column(_, _)
753                    | MirScalarExpr::Literal(_, _)
754                    | MirScalarExpr::CallUnmaterializable(_) => (),
755                    MirScalarExpr::CallUnary { func, expr } => {
756                        if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
757                            *e = eval(e);
758                        } else if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
759                            if let MirScalarExpr::CallVariadic {
760                                func: VariadicFunc::RecordCreate { .. },
761                                exprs,
762                            } = &mut **expr
763                            {
764                                *e = exprs.swap_remove(i);
765                            }
766                        }
767                    }
768                    MirScalarExpr::CallBinary { func, expr1, expr2 } => {
769                        if expr1.is_literal() && expr2.is_literal() {
770                            *e = eval(e);
771                        } else if (expr1.is_literal_null() || expr2.is_literal_null())
772                            && func.propagates_nulls()
773                        {
774                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
775                        } else if let Some(err) = expr1.as_literal_err() {
776                            *e = MirScalarExpr::literal(
777                                Err(err.clone()),
778                                e.typ(column_types).scalar_type,
779                            );
780                        } else if let Some(err) = expr2.as_literal_err() {
781                            *e = MirScalarExpr::literal(
782                                Err(err.clone()),
783                                e.typ(column_types).scalar_type,
784                            );
785                        } else if let BinaryFunc::IsLikeMatchCaseInsensitive(_) = func {
786                            if expr2.is_literal() {
787                                // We can at least precompile the regex.
788                                let pattern = expr2.as_literal_str().unwrap();
789                                *e = match like_pattern::compile(pattern, true) {
790                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
791                                        func::IsLikeMatch(matcher),
792                                    )),
793                                    Err(err) => MirScalarExpr::literal(
794                                        Err(err),
795                                        e.typ(column_types).scalar_type,
796                                    ),
797                                };
798                            }
799                        } else if let BinaryFunc::IsLikeMatchCaseSensitive(_) = func {
800                            if expr2.is_literal() {
801                                // We can at least precompile the regex.
802                                let pattern = expr2.as_literal_str().unwrap();
803                                *e = match like_pattern::compile(pattern, false) {
804                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
805                                        func::IsLikeMatch(matcher),
806                                    )),
807                                    Err(err) => MirScalarExpr::literal(
808                                        Err(err),
809                                        e.typ(column_types).scalar_type,
810                                    ),
811                                };
812                            }
813                        } else if matches!(
814                            func,
815                            BinaryFunc::IsRegexpMatchCaseSensitive(_)
816                                | BinaryFunc::IsRegexpMatchCaseInsensitive(_)
817                        ) {
818                            let case_insensitive =
819                                matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
820                            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
821                                *e = match Regex::new(
822                                    row.unpack_first().unwrap_str(),
823                                    case_insensitive,
824                                ) {
825                                    Ok(regex) => expr1.take().call_unary(UnaryFunc::IsRegexpMatch(
826                                        func::IsRegexpMatch(regex),
827                                    )),
828                                    Err(err) => MirScalarExpr::literal(
829                                        Err(err.into()),
830                                        e.typ(column_types).scalar_type,
831                                    ),
832                                };
833                            }
834                        } else if let BinaryFunc::ExtractInterval(_) = *func
835                            && expr1.is_literal()
836                        {
837                            let units = expr1.as_literal_str().unwrap();
838                            *e = match units.parse::<DateTimeUnits>() {
839                                Ok(units) => MirScalarExpr::CallUnary {
840                                    func: UnaryFunc::ExtractInterval(func::ExtractInterval(units)),
841                                    expr: Box::new(expr2.take()),
842                                },
843                                Err(_) => MirScalarExpr::literal(
844                                    Err(EvalError::UnknownUnits(units.into())),
845                                    e.typ(column_types).scalar_type,
846                                ),
847                            }
848                        } else if let BinaryFunc::ExtractTime(_) = *func
849                            && expr1.is_literal()
850                        {
851                            let units = expr1.as_literal_str().unwrap();
852                            *e = match units.parse::<DateTimeUnits>() {
853                                Ok(units) => MirScalarExpr::CallUnary {
854                                    func: UnaryFunc::ExtractTime(func::ExtractTime(units)),
855                                    expr: Box::new(expr2.take()),
856                                },
857                                Err(_) => MirScalarExpr::literal(
858                                    Err(EvalError::UnknownUnits(units.into())),
859                                    e.typ(column_types).scalar_type,
860                                ),
861                            }
862                        } else if let BinaryFunc::ExtractTimestamp(_) = *func
863                            && expr1.is_literal()
864                        {
865                            let units = expr1.as_literal_str().unwrap();
866                            *e = match units.parse::<DateTimeUnits>() {
867                                Ok(units) => MirScalarExpr::CallUnary {
868                                    func: UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(
869                                        units,
870                                    )),
871                                    expr: Box::new(expr2.take()),
872                                },
873                                Err(_) => MirScalarExpr::literal(
874                                    Err(EvalError::UnknownUnits(units.into())),
875                                    e.typ(column_types).scalar_type,
876                                ),
877                            }
878                        } else if let BinaryFunc::ExtractTimestampTz(_) = *func
879                            && expr1.is_literal()
880                        {
881                            let units = expr1.as_literal_str().unwrap();
882                            *e = match units.parse::<DateTimeUnits>() {
883                                Ok(units) => MirScalarExpr::CallUnary {
884                                    func: UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(
885                                        units,
886                                    )),
887                                    expr: Box::new(expr2.take()),
888                                },
889                                Err(_) => MirScalarExpr::literal(
890                                    Err(EvalError::UnknownUnits(units.into())),
891                                    e.typ(column_types).scalar_type,
892                                ),
893                            }
894                        } else if let BinaryFunc::ExtractDate(_) = *func
895                            && expr1.is_literal()
896                        {
897                            let units = expr1.as_literal_str().unwrap();
898                            *e = match units.parse::<DateTimeUnits>() {
899                                Ok(units) => MirScalarExpr::CallUnary {
900                                    func: UnaryFunc::ExtractDate(func::ExtractDate(units)),
901                                    expr: Box::new(expr2.take()),
902                                },
903                                Err(_) => MirScalarExpr::literal(
904                                    Err(EvalError::UnknownUnits(units.into())),
905                                    e.typ(column_types).scalar_type,
906                                ),
907                            }
908                        } else if let BinaryFunc::DatePartInterval(_) = *func
909                            && expr1.is_literal()
910                        {
911                            let units = expr1.as_literal_str().unwrap();
912                            *e = match units.parse::<DateTimeUnits>() {
913                                Ok(units) => MirScalarExpr::CallUnary {
914                                    func: UnaryFunc::DatePartInterval(func::DatePartInterval(
915                                        units,
916                                    )),
917                                    expr: Box::new(expr2.take()),
918                                },
919                                Err(_) => MirScalarExpr::literal(
920                                    Err(EvalError::UnknownUnits(units.into())),
921                                    e.typ(column_types).scalar_type,
922                                ),
923                            }
924                        } else if let BinaryFunc::DatePartTime(_) = *func
925                            && expr1.is_literal()
926                        {
927                            let units = expr1.as_literal_str().unwrap();
928                            *e = match units.parse::<DateTimeUnits>() {
929                                Ok(units) => MirScalarExpr::CallUnary {
930                                    func: UnaryFunc::DatePartTime(func::DatePartTime(units)),
931                                    expr: Box::new(expr2.take()),
932                                },
933                                Err(_) => MirScalarExpr::literal(
934                                    Err(EvalError::UnknownUnits(units.into())),
935                                    e.typ(column_types).scalar_type,
936                                ),
937                            }
938                        } else if let BinaryFunc::DatePartTimestamp(_) = *func
939                            && expr1.is_literal()
940                        {
941                            let units = expr1.as_literal_str().unwrap();
942                            *e = match units.parse::<DateTimeUnits>() {
943                                Ok(units) => MirScalarExpr::CallUnary {
944                                    func: UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(
945                                        units,
946                                    )),
947                                    expr: Box::new(expr2.take()),
948                                },
949                                Err(_) => MirScalarExpr::literal(
950                                    Err(EvalError::UnknownUnits(units.into())),
951                                    e.typ(column_types).scalar_type,
952                                ),
953                            }
954                        } else if let BinaryFunc::DatePartTimestampTz(_) = *func
955                            && expr1.is_literal()
956                        {
957                            let units = expr1.as_literal_str().unwrap();
958                            *e = match units.parse::<DateTimeUnits>() {
959                                Ok(units) => MirScalarExpr::CallUnary {
960                                    func: UnaryFunc::DatePartTimestampTz(
961                                        func::DatePartTimestampTz(units),
962                                    ),
963                                    expr: Box::new(expr2.take()),
964                                },
965                                Err(_) => MirScalarExpr::literal(
966                                    Err(EvalError::UnknownUnits(units.into())),
967                                    e.typ(column_types).scalar_type,
968                                ),
969                            }
970                        } else if let BinaryFunc::DateTruncTimestamp(_) = *func
971                            && expr1.is_literal()
972                        {
973                            let units = expr1.as_literal_str().unwrap();
974                            *e = match units.parse::<DateTimeUnits>() {
975                                Ok(units) => MirScalarExpr::CallUnary {
976                                    func: UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(
977                                        units,
978                                    )),
979                                    expr: Box::new(expr2.take()),
980                                },
981                                Err(_) => MirScalarExpr::literal(
982                                    Err(EvalError::UnknownUnits(units.into())),
983                                    e.typ(column_types).scalar_type,
984                                ),
985                            }
986                        } else if let BinaryFunc::DateTruncTimestampTz(_) = *func
987                            && expr1.is_literal()
988                        {
989                            let units = expr1.as_literal_str().unwrap();
990                            *e = match units.parse::<DateTimeUnits>() {
991                                Ok(units) => MirScalarExpr::CallUnary {
992                                    func: UnaryFunc::DateTruncTimestampTz(
993                                        func::DateTruncTimestampTz(units),
994                                    ),
995                                    expr: Box::new(expr2.take()),
996                                },
997                                Err(_) => MirScalarExpr::literal(
998                                    Err(EvalError::UnknownUnits(units.into())),
999                                    e.typ(column_types).scalar_type,
1000                                ),
1001                            }
1002                        } else if matches!(func, BinaryFunc::TimezoneTimestampBinary(_))
1003                            && expr1.is_literal()
1004                        {
1005                            // If the timezone argument is a literal, and we're applying the function on many rows at the same
1006                            // time we really don't want to parse it again and again, so we parse it once and embed it into the
1007                            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
1008                            let tz = expr1.as_literal_str().unwrap();
1009                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1010                                Ok(tz) => MirScalarExpr::CallUnary {
1011                                    func: UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz)),
1012                                    expr: Box::new(expr2.take()),
1013                                },
1014                                Err(err) => MirScalarExpr::literal(
1015                                    Err(err),
1016                                    e.typ(column_types).scalar_type,
1017                                ),
1018                            }
1019                        } else if matches!(func, BinaryFunc::TimezoneTimestampTzBinary(_))
1020                            && expr1.is_literal()
1021                        {
1022                            let tz = expr1.as_literal_str().unwrap();
1023                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1024                                Ok(tz) => MirScalarExpr::CallUnary {
1025                                    func: UnaryFunc::TimezoneTimestampTz(
1026                                        func::TimezoneTimestampTz(tz),
1027                                    ),
1028                                    expr: Box::new(expr2.take()),
1029                                },
1030                                Err(err) => MirScalarExpr::literal(
1031                                    Err(err),
1032                                    e.typ(column_types).scalar_type,
1033                                ),
1034                            }
1035                        } else if let BinaryFunc::ToCharTimestamp(_) = *func
1036                            && expr2.is_literal()
1037                        {
1038                            let format_str = expr2.as_literal_str().unwrap();
1039                            *e = MirScalarExpr::CallUnary {
1040                                func: UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
1041                                    format_string: format_str.to_string(),
1042                                    format: DateTimeFormat::compile(format_str),
1043                                }),
1044                                expr: Box::new(expr1.take()),
1045                            };
1046                        } else if let BinaryFunc::ToCharTimestampTz(_) = *func
1047                            && expr2.is_literal()
1048                        {
1049                            let format_str = expr2.as_literal_str().unwrap();
1050                            *e = MirScalarExpr::CallUnary {
1051                                func: UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
1052                                    format_string: format_str.to_string(),
1053                                    format: DateTimeFormat::compile(format_str),
1054                                }),
1055                                expr: Box::new(expr1.take()),
1056                            };
1057                        } else if matches!(*func, BinaryFunc::Eq(_) | BinaryFunc::NotEq(_))
1058                            && expr2 < expr1
1059                        {
1060                            // Canonically order elements so that deduplication works better.
1061                            // Also, the below `Literal([c1, c2]) = record_create(e1, e2)` matching
1062                            // relies on this canonical ordering.
1063                            mem::swap(expr1, expr2);
1064                        } else if let (
1065                            BinaryFunc::Eq(_),
1066                            MirScalarExpr::Literal(
1067                                Ok(lit_row),
1068                                SqlColumnType {
1069                                    scalar_type:
1070                                        SqlScalarType::Record {
1071                                            fields: field_types,
1072                                            ..
1073                                        },
1074                                    ..
1075                                },
1076                            ),
1077                            MirScalarExpr::CallVariadic {
1078                                func: VariadicFunc::RecordCreate { .. },
1079                                exprs: rec_create_args,
1080                            },
1081                        ) = (&*func, &**expr1, &**expr2)
1082                        {
1083                            // Literal([c1, c2]) = record_create(e1, e2)
1084                            //  -->
1085                            // c1 = e1 AND c2 = e2
1086                            //
1087                            // (Records are represented as lists.)
1088                            //
1089                            // `MapFilterProject::literal_constraints` relies on this transform,
1090                            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
1091                            match lit_row.unpack_first() {
1092                                Datum::List(datum_list) => {
1093                                    *e = MirScalarExpr::CallVariadic {
1094                                        func: VariadicFunc::And,
1095                                        exprs: datum_list
1096                                            .iter()
1097                                            .zip_eq(field_types)
1098                                            .zip_eq(rec_create_args)
1099                                            .map(|((d, (_, typ)), a)| {
1100                                                MirScalarExpr::literal_ok(
1101                                                    d,
1102                                                    typ.scalar_type.clone(),
1103                                                )
1104                                                .call_binary(a.clone(), func::Eq)
1105                                            })
1106                                            .collect(),
1107                                    };
1108                                }
1109                                _ => {}
1110                            }
1111                        } else if let (
1112                            BinaryFunc::Eq(_),
1113                            MirScalarExpr::CallVariadic {
1114                                func: VariadicFunc::RecordCreate { .. },
1115                                exprs: rec_create_args1,
1116                            },
1117                            MirScalarExpr::CallVariadic {
1118                                func: VariadicFunc::RecordCreate { .. },
1119                                exprs: rec_create_args2,
1120                            },
1121                        ) = (&*func, &**expr1, &**expr2)
1122                        {
1123                            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
1124                            //  -->
1125                            // a1 = b1 AND a2 = b2 AND ...
1126                            //
1127                            // This is similar to the previous reduction, but this one kicks in also
1128                            // when only some (or none) of the record fields are literals. This
1129                            // enables the discovery of literal constraints for those fields.
1130                            //
1131                            // Note that there is a similar decomposition in
1132                            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
1133                            // pipeline than the compilation of IN lists to `record_create`.
1134                            *e = MirScalarExpr::CallVariadic {
1135                                func: VariadicFunc::And,
1136                                exprs: rec_create_args1
1137                                    .into_iter()
1138                                    .zip_eq(rec_create_args2)
1139                                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
1140                                    .collect(),
1141                            }
1142                        }
1143                    }
1144                    MirScalarExpr::CallVariadic { .. } => {
1145                        e.flatten_associative();
1146                        let (func, exprs) = match e {
1147                            MirScalarExpr::CallVariadic { func, exprs } => (func, exprs),
1148                            _ => unreachable!("`flatten_associative` shouldn't change node type"),
1149                        };
1150                        if *func == VariadicFunc::Coalesce {
1151                            // If all inputs are null, output is null. This check must
1152                            // be done before `exprs.retain...` because `e.typ` requires
1153                            // > 0 `exprs` remain.
1154                            if exprs.iter().all(|expr| expr.is_literal_null()) {
1155                                *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1156                                return;
1157                            }
1158
1159                            // Remove any null values if not all values are null.
1160                            exprs.retain(|e| !e.is_literal_null());
1161
1162                            // Find the first argument that is a literal or non-nullable
1163                            // column. All arguments after it get ignored, so throw them
1164                            // away. This intentionally throws away errors that can
1165                            // never happen.
1166                            if let Some(i) = exprs
1167                                .iter()
1168                                .position(|e| e.is_literal() || !e.typ(column_types).nullable)
1169                            {
1170                                exprs.truncate(i + 1);
1171                            }
1172
1173                            // Deduplicate arguments in cases like `coalesce(#0, #0)`.
1174                            let mut prior_exprs = BTreeSet::new();
1175                            exprs.retain(|e| prior_exprs.insert(e.clone()));
1176
1177                            if exprs.len() == 1 {
1178                                // Only one argument, so the coalesce is a no-op.
1179                                *e = exprs[0].take();
1180                            }
1181                        } else if exprs.iter().all(|e| e.is_literal()) {
1182                            *e = eval(e);
1183                        } else if func.propagates_nulls()
1184                            && exprs.iter().any(|e| e.is_literal_null())
1185                        {
1186                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1187                        } else if let Some(err) = exprs.iter().find_map(|e| e.as_literal_err()) {
1188                            *e = MirScalarExpr::literal(
1189                                Err(err.clone()),
1190                                e.typ(column_types).scalar_type,
1191                            );
1192                        } else if *func == VariadicFunc::RegexpMatch
1193                            && exprs[1].is_literal()
1194                            && exprs.get(2).map_or(true, |e| e.is_literal())
1195                        {
1196                            let needle = exprs[1].as_literal_str().unwrap();
1197                            let flags = match exprs.len() {
1198                                3 => exprs[2].as_literal_str().unwrap(),
1199                                _ => "",
1200                            };
1201                            *e = match func::build_regex(needle, flags) {
1202                                Ok(regex) => mem::take(exprs)
1203                                    .into_first()
1204                                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
1205                                Err(err) => MirScalarExpr::literal(
1206                                    Err(err),
1207                                    e.typ(column_types).scalar_type,
1208                                ),
1209                            };
1210                        } else if *func == VariadicFunc::RegexpReplace
1211                            && exprs[1].is_literal()
1212                            && exprs.get(3).map_or(true, |e| e.is_literal())
1213                        {
1214                            let pattern = exprs[1].as_literal_str().unwrap();
1215                            let flags = exprs
1216                                .get(3)
1217                                .map_or("", |expr| expr.as_literal_str().unwrap());
1218                            let (limit, flags) = regexp_replace_parse_flags(flags);
1219
1220                            // The behavior of `regexp_replace` is that if the data is `NULL`, the
1221                            // function returns `NULL`, independently of whether the pattern or
1222                            // flags are correct. We need to check for this case and introduce an
1223                            // if-then-else on the error path to only surface the error if the first
1224                            // input is non-NULL.
1225                            *e = match func::build_regex(pattern, &flags) {
1226                                Ok(regex) => {
1227                                    let mut exprs = mem::take(exprs);
1228                                    let replacement = exprs.swap_remove(2);
1229                                    let source = exprs.swap_remove(0);
1230                                    source.call_binary(
1231                                        replacement,
1232                                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
1233                                    )
1234                                }
1235                                Err(err) => {
1236                                    let mut exprs = mem::take(exprs);
1237                                    let source = exprs.swap_remove(0);
1238                                    let scalar_type = e.typ(column_types).scalar_type;
1239                                    // We need to return `NULL` on `NULL` input, and error otherwise.
1240                                    source.call_is_null().if_then_else(
1241                                        MirScalarExpr::literal_null(scalar_type.clone()),
1242                                        MirScalarExpr::literal(Err(err), scalar_type),
1243                                    )
1244                                }
1245                            };
1246                        } else if *func == VariadicFunc::RegexpSplitToArray
1247                            && exprs[1].is_literal()
1248                            && exprs.get(2).map_or(true, |e| e.is_literal())
1249                        {
1250                            let needle = exprs[1].as_literal_str().unwrap();
1251                            let flags = match exprs.len() {
1252                                3 => exprs[2].as_literal_str().unwrap(),
1253                                _ => "",
1254                            };
1255                            *e = match func::build_regex(needle, flags) {
1256                                Ok(regex) => mem::take(exprs).into_first().call_unary(
1257                                    UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(regex)),
1258                                ),
1259                                Err(err) => MirScalarExpr::literal(
1260                                    Err(err),
1261                                    e.typ(column_types).scalar_type,
1262                                ),
1263                            };
1264                        } else if *func == VariadicFunc::ListIndex && is_list_create_call(&exprs[0])
1265                        {
1266                            // We are looking for ListIndex(ListCreate, literal), and eliminate
1267                            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
1268                            let ind_exprs = exprs.split_off(1);
1269                            let top_list_create = exprs.swap_remove(0);
1270                            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
1271                        } else if *func == VariadicFunc::Or || *func == VariadicFunc::And {
1272                            // Note: It's important that we have called `flatten_associative` above.
1273                            e.undistribute_and_or();
1274                            e.reduce_and_canonicalize_and_or();
1275                        } else if let VariadicFunc::TimezoneTime = func {
1276                            if exprs[0].is_literal() && exprs[2].is_literal_ok() {
1277                                let tz = exprs[0].as_literal_str().unwrap();
1278                                *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1279                                    Ok(tz) => MirScalarExpr::CallUnary {
1280                                        func: UnaryFunc::TimezoneTime(func::TimezoneTime {
1281                                            tz,
1282                                            wall_time: exprs[2]
1283                                                .as_literal()
1284                                                .unwrap()
1285                                                .unwrap()
1286                                                .unwrap_timestamptz()
1287                                                .naive_utc(),
1288                                        }),
1289                                        expr: Box::new(exprs[1].take()),
1290                                    },
1291                                    Err(err) => MirScalarExpr::literal(
1292                                        Err(err),
1293                                        e.typ(column_types).scalar_type,
1294                                    ),
1295                                }
1296                            }
1297                        }
1298                    }
1299                    MirScalarExpr::If { cond, then, els } => {
1300                        if let Some(literal) = cond.as_literal() {
1301                            match literal {
1302                                Ok(Datum::True) => *e = then.take(),
1303                                Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
1304                                Err(err) => {
1305                                    *e = MirScalarExpr::Literal(
1306                                        Err(err.clone()),
1307                                        then.typ(column_types).union(&els.typ(column_types)),
1308                                    )
1309                                }
1310                                _ => unreachable!(),
1311                            }
1312                        } else if then == els {
1313                            *e = then.take();
1314                        } else if then.is_literal_ok()
1315                            && els.is_literal_ok()
1316                            && then.typ(column_types).scalar_type == SqlScalarType::Bool
1317                            && els.typ(column_types).scalar_type == SqlScalarType::Bool
1318                        {
1319                            match (then.as_literal(), els.as_literal()) {
1320                                // Note: NULLs from the condition should not be propagated to the result
1321                                // of the expression.
1322                                (Some(Ok(Datum::True)), _) => {
1323                                    // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
1324                                    // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
1325                                    *e = cond
1326                                        .clone()
1327                                        .call_is_null()
1328                                        .not()
1329                                        .and(cond.take())
1330                                        .or(els.take());
1331                                }
1332                                (Some(Ok(Datum::False)), _) => {
1333                                    // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
1334                                    // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
1335                                    *e = cond
1336                                        .clone()
1337                                        .not()
1338                                        .or(cond.take().call_is_null())
1339                                        .and(els.take());
1340                                }
1341                                (_, Some(Ok(Datum::True))) => {
1342                                    // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
1343                                    // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
1344                                    *e = cond
1345                                        .clone()
1346                                        .not()
1347                                        .or(cond.take().call_is_null())
1348                                        .or(then.take());
1349                                }
1350                                (_, Some(Ok(Datum::False))) => {
1351                                    // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
1352                                    // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
1353                                    *e = cond
1354                                        .clone()
1355                                        .call_is_null()
1356                                        .not()
1357                                        .and(cond.take())
1358                                        .and(then.take());
1359                                }
1360                                _ => {}
1361                            }
1362                        } else {
1363                            // Equivalent expression structure would allow us to push the `If` into the expression.
1364                            // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
1365                            //
1366                            // We have to also make sure that the expressions that will end up in
1367                            // the two `If` branches have unionable types. Otherwise, the `If` could
1368                            // not be typed by `typ`. An example where this could cause an issue is
1369                            // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
1370                            // range of input types. (In theory, we could still do the optimization
1371                            // in this case by inserting appropriate casts, but this corner case is
1372                            // not worth the complication for now.)
1373                            // See https://github.com/MaterializeInc/database-issues/issues/9182
1374                            match (&mut **then, &mut **els) {
1375                                (
1376                                    MirScalarExpr::CallUnary { func: f1, expr: e1 },
1377                                    MirScalarExpr::CallUnary { func: f2, expr: e2 },
1378                                ) if f1 == f2
1379                                    && e1
1380                                        .typ(column_types)
1381                                        .try_union(&e2.typ(column_types))
1382                                        .is_ok() =>
1383                                {
1384                                    *e = cond
1385                                        .take()
1386                                        .if_then_else(e1.take(), e2.take())
1387                                        .call_unary(f1.clone());
1388                                }
1389                                (
1390                                    MirScalarExpr::CallBinary {
1391                                        func: f1,
1392                                        expr1: e1a,
1393                                        expr2: e2a,
1394                                    },
1395                                    MirScalarExpr::CallBinary {
1396                                        func: f2,
1397                                        expr1: e1b,
1398                                        expr2: e2b,
1399                                    },
1400                                ) if f1 == f2
1401                                    && e1a == e1b
1402                                    && e2a
1403                                        .typ(column_types)
1404                                        .try_union(&e2b.typ(column_types))
1405                                        .is_ok() =>
1406                                {
1407                                    *e = e1a.take().call_binary(
1408                                        cond.take().if_then_else(e2a.take(), e2b.take()),
1409                                        f1.clone(),
1410                                    );
1411                                }
1412                                (
1413                                    MirScalarExpr::CallBinary {
1414                                        func: f1,
1415                                        expr1: e1a,
1416                                        expr2: e2a,
1417                                    },
1418                                    MirScalarExpr::CallBinary {
1419                                        func: f2,
1420                                        expr1: e1b,
1421                                        expr2: e2b,
1422                                    },
1423                                ) if f1 == f2
1424                                    && e2a == e2b
1425                                    && e1a
1426                                        .typ(column_types)
1427                                        .try_union(&e1b.typ(column_types))
1428                                        .is_ok() =>
1429                                {
1430                                    *e = cond
1431                                        .take()
1432                                        .if_then_else(e1a.take(), e1b.take())
1433                                        .call_binary(e2a.take(), f1.clone());
1434                                }
1435                                _ => {}
1436                            }
1437                        }
1438                    }
1439                },
1440            );
1441        }
1442
1443        /* #region `reduce_list_create_list_index_literal` and helper functions */
1444
1445        fn list_create_type(list_create: &MirScalarExpr) -> SqlScalarType {
1446            if let MirScalarExpr::CallVariadic {
1447                func: VariadicFunc::ListCreate { elem_type: typ },
1448                ..
1449            } = list_create
1450            {
1451                (*typ).clone()
1452            } else {
1453                unreachable!()
1454            }
1455        }
1456
1457        fn is_list_create_call(expr: &MirScalarExpr) -> bool {
1458            matches!(
1459                expr,
1460                MirScalarExpr::CallVariadic {
1461                    func: VariadicFunc::ListCreate { .. },
1462                    ..
1463                }
1464            )
1465        }
1466
1467        /// Partial-evaluates a list indexing with a literal directly after a list creation.
1468        ///
1469        /// Multi-dimensional lists are handled by a single call to this function, with multiple
1470        /// elements in index_exprs (of which not all need to be literals), and nested ListCreates
1471        /// in list_create_to_reduce.
1472        ///
1473        /// # Examples
1474        ///
1475        /// `LIST[f1,f2][2]` --> `f2`.
1476        ///
1477        /// A multi-dimensional list, with only some of the indexes being literals:
1478        /// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
1479        ///
1480        /// See more examples in list.slt.
1481        fn reduce_list_create_list_index_literal(
1482            mut list_create_to_reduce: MirScalarExpr,
1483            mut index_exprs: Vec<MirScalarExpr>,
1484        ) -> MirScalarExpr {
1485            // We iterate over the index_exprs and remove literals, but keep non-literals.
1486            // When we encounter a non-literal, we need to dig into the nested ListCreates:
1487            // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
1488            // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
1489            // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
1490            // that are at the current level (except those that disappeared due to
1491            // literals at earlier levels), index into them with the literal, and change each
1492            // element in `list_create_mut_refs` to the result.
1493            // We also record mut refs to all the earlier `element_type` references that we have
1494            // seen in ListCreate calls, because when we process a literal index, we need to remove
1495            // one layer of list type from all these earlier ListCreate `element_type`s.
1496            let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
1497            let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
1498            let mut i = 0;
1499            while i < index_exprs.len()
1500                && list_create_mut_refs
1501                    .iter()
1502                    .all(|lc| is_list_create_call(lc))
1503            {
1504                if index_exprs[i].is_literal_ok() {
1505                    // We can remove this index.
1506                    let removed_index = index_exprs.remove(i);
1507                    let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
1508                        Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
1509                        _ => unreachable!(), // always an Int64, see plan_index_list
1510                    };
1511                    // For each list_create referenced by list_create_mut_refs, substitute it by its
1512                    // `index`th argument (or null).
1513                    for list_create in &mut list_create_mut_refs {
1514                        let list_create_args = match list_create {
1515                            MirScalarExpr::CallVariadic {
1516                                func: VariadicFunc::ListCreate { elem_type: _ },
1517                                exprs,
1518                            } => exprs,
1519                            _ => unreachable!(), // func cannot be anything else than a ListCreate
1520                        };
1521                        // ListIndex gives null on an out-of-bounds index
1522                        if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap()
1523                        {
1524                            let index: usize = index_i64.try_into().unwrap();
1525                            **list_create = list_create_args.swap_remove(index);
1526                        } else {
1527                            let typ = list_create_type(list_create);
1528                            **list_create = MirScalarExpr::literal_null(typ);
1529                        }
1530                    }
1531                    // Peel one layer off of each of the earlier element types.
1532                    for t in earlier_list_create_types.iter_mut() {
1533                        if let SqlScalarType::List {
1534                            element_type,
1535                            custom_id: _,
1536                        } = t
1537                        {
1538                            **t = *element_type.clone();
1539                            // These are not the same types anymore, so remove custom_ids all the
1540                            // way down.
1541                            let mut u = &mut **t;
1542                            while let SqlScalarType::List {
1543                                element_type,
1544                                custom_id,
1545                            } = u
1546                            {
1547                                *custom_id = None;
1548                                u = &mut **element_type;
1549                            }
1550                        } else {
1551                            unreachable!("already matched below");
1552                        }
1553                    }
1554                } else {
1555                    // We can't remove this index, so we can't reduce any of the ListCreates at this
1556                    // level. So we change list_create_mut_refs to refer to all the arguments of all
1557                    // the ListCreates currently referenced by list_create_mut_refs.
1558                    list_create_mut_refs = list_create_mut_refs
1559                        .into_iter()
1560                        .flat_map(|list_create| match list_create {
1561                            MirScalarExpr::CallVariadic {
1562                                func: VariadicFunc::ListCreate { elem_type },
1563                                exprs: list_create_args,
1564                            } => {
1565                                earlier_list_create_types.push(elem_type);
1566                                list_create_args
1567                            }
1568                            // func cannot be anything else than a ListCreate
1569                            _ => unreachable!(),
1570                        })
1571                        .collect();
1572                    i += 1; // next index_expr
1573                }
1574            }
1575            // If all list indexes have been evaluated, return the reduced expression.
1576            // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
1577            if index_exprs.is_empty() {
1578                assert_eq!(list_create_mut_refs.len(), 1);
1579                list_create_to_reduce
1580            } else {
1581                let mut exprs: Vec<MirScalarExpr> = vec![list_create_to_reduce];
1582                exprs.append(&mut index_exprs);
1583                MirScalarExpr::CallVariadic {
1584                    func: VariadicFunc::ListIndex,
1585                    exprs,
1586                }
1587            }
1588        }
1589
1590        /* #endregion */
1591    }
1592
1593    /// Decompose an IsNull expression into a disjunction of
1594    /// simpler expressions.
1595    ///
1596    /// Assumes that `self` is the expression inside of an IsNull.
1597    /// Returns `Some(expressions)` if the outer IsNull is to be
1598    /// replaced by some other expression. Note: if it returns
1599    /// None, it might still have mutated *self.
1600    fn decompose_is_null(&mut self) -> Option<MirScalarExpr> {
1601        // TODO: allow simplification of unmaterializable functions
1602
1603        match self {
1604            MirScalarExpr::CallUnary {
1605                func,
1606                expr: inner_expr,
1607            } => {
1608                if !func.introduces_nulls() {
1609                    if func.propagates_nulls() {
1610                        *self = inner_expr.take();
1611                        return self.decompose_is_null();
1612                    } else {
1613                        // Different from CallBinary and CallVariadic, because of determinism. See
1614                        // https://materializeinc.slack.com/archives/C01BE3RN82F/p1657644478517709
1615                        return Some(MirScalarExpr::literal_false());
1616                    }
1617                }
1618            }
1619            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1620                // (<expr1> <op> <expr2>) IS NULL can often be simplified to
1621                // (<expr1> IS NULL) OR (<expr2> IS NULL).
1622                if func.propagates_nulls() && !func.introduces_nulls() {
1623                    let expr1 = expr1.take().call_is_null();
1624                    let expr2 = expr2.take().call_is_null();
1625                    return Some(expr1.or(expr2));
1626                }
1627            }
1628            MirScalarExpr::CallVariadic { func, exprs } => {
1629                if func.propagates_nulls() && !func.introduces_nulls() {
1630                    let exprs = exprs.into_iter().map(|e| e.take().call_is_null()).collect();
1631                    return Some(MirScalarExpr::CallVariadic {
1632                        func: VariadicFunc::Or,
1633                        exprs,
1634                    });
1635                }
1636            }
1637            _ => {}
1638        }
1639
1640        None
1641    }
1642
1643    /// Flattens a chain of calls to associative variadic functions
1644    /// (For example: ORs or ANDs)
1645    pub fn flatten_associative(&mut self) {
1646        match self {
1647            MirScalarExpr::CallVariadic {
1648                exprs: outer_operands,
1649                func: outer_func,
1650            } if outer_func.is_associative() => {
1651                *outer_operands = outer_operands
1652                    .into_iter()
1653                    .flat_map(|o| {
1654                        if let MirScalarExpr::CallVariadic {
1655                            exprs: inner_operands,
1656                            func: inner_func,
1657                        } = o
1658                        {
1659                            if *inner_func == *outer_func {
1660                                mem::take(inner_operands)
1661                            } else {
1662                                vec![o.take()]
1663                            }
1664                        } else {
1665                            vec![o.take()]
1666                        }
1667                    })
1668                    .collect();
1669            }
1670            _ => {}
1671        }
1672    }
1673
1674    /* #region AND/OR canonicalization and transformations  */
1675
1676    /// Canonicalizes AND/OR, and does some straightforward simplifications
1677    fn reduce_and_canonicalize_and_or(&mut self) {
1678        // We do this until fixed point, because after undistribute_and_or calls us, it relies on
1679        // the property that self is not an 1-arg AND/OR. Just one application of our loop body
1680        // can't ensure this, because the application itself might create a 1-arg AND/OR.
1681        let mut old_self = MirScalarExpr::column(0);
1682        while old_self != *self {
1683            old_self = self.clone();
1684            match self {
1685                MirScalarExpr::CallVariadic {
1686                    func: func @ (VariadicFunc::And | VariadicFunc::Or),
1687                    exprs,
1688                } => {
1689                    // Canonically order elements so that various deduplications work better,
1690                    // e.g., in undistribute_and_or.
1691                    // Also, extract_equal_or_both_null_inner depends on the args being sorted.
1692                    exprs.sort();
1693
1694                    // x AND/OR x --> x
1695                    exprs.dedup(); // this also needs the above sorting
1696
1697                    if exprs.len() == 1 {
1698                        // AND/OR of 1 argument evaluates to that argument
1699                        *self = exprs.swap_remove(0);
1700                    } else if exprs.len() == 0 {
1701                        // AND/OR of 0 arguments evaluates to true/false
1702                        *self = func.unit_of_and_or();
1703                    } else if exprs.iter().any(|e| *e == func.zero_of_and_or()) {
1704                        // short-circuiting
1705                        *self = func.zero_of_and_or();
1706                    } else {
1707                        // a AND true --> a
1708                        // a OR false --> a
1709                        exprs.retain(|e| *e != func.unit_of_and_or());
1710                    }
1711                }
1712                _ => {}
1713            }
1714        }
1715    }
1716
1717    /// Transforms !(a && b) into !a || !b, and !(a || b) into !a && !b
1718    fn demorgans(&mut self) {
1719        if let MirScalarExpr::CallUnary {
1720            expr: inner,
1721            func: UnaryFunc::Not(func::Not),
1722        } = self
1723        {
1724            inner.flatten_associative();
1725            match &mut **inner {
1726                MirScalarExpr::CallVariadic {
1727                    func: inner_func @ (VariadicFunc::And | VariadicFunc::Or),
1728                    exprs,
1729                } => {
1730                    *inner_func = inner_func.switch_and_or();
1731                    *exprs = exprs.into_iter().map(|e| e.take().not()).collect();
1732                    *self = (*inner).take(); // Removes the outer not
1733                }
1734                _ => {}
1735            }
1736        }
1737    }
1738
1739    /// AND/OR undistribution (factoring out) to apply at each `MirScalarExpr`.
1740    ///
1741    /// This method attempts to apply one of the [distribution laws][distributivity]
1742    /// (in a direction opposite to their name):
1743    /// ```text
1744    /// (a && b) || (a && c) --> a && (b || c)  // Undistribute-OR
1745    /// (a || b) && (a || c) --> a || (b && c)  // Undistribute-AND
1746    /// ```
1747    /// or one of their corresponding two [absorption law][absorption] special
1748    /// cases:
1749    /// ```text
1750    /// a || (a && c)  -->  a  // Absorb-OR
1751    /// a && (a || c)  -->  a  // Absorb-AND
1752    /// ```
1753    ///
1754    /// The method also works with more than 2 arguments at the top, e.g.
1755    /// ```text
1756    /// (a && b) || (a && c) || (a && d)  -->  a && (b || c || d)
1757    /// ```
1758    /// It can also factor out only a subset of the top arguments, e.g.
1759    /// ```text
1760    /// (a && b) || (a && c) || (d && e)  -->  (a && (b || c)) || (d && e)
1761    /// ```
1762    ///
1763    /// Note that sometimes there are two overlapping possibilities to factor
1764    /// out from, e.g.
1765    /// ```text
1766    /// (a && b) || (a && c) || (d && c)
1767    /// ```
1768    /// Here we can factor out `a` from the 1. and 2. terms, or we can
1769    /// factor out `c` from the 2. and 3. terms. One of these might lead to
1770    /// more/better undistribution opportunities later, but we just pick one
1771    /// locally, because recursively trying out all of them would lead to
1772    /// exponential run time.
1773    ///
1774    /// The local heuristic is that we prefer a candidate that leads to an
1775    /// absorption, or if there is no such one then we simply pick the first. In
1776    /// case of multiple absorption candidates, it doesn't matter which one we
1777    /// pick, because applying an absorption cannot adversely affect the
1778    /// possibility of applying other absorptions.
1779    ///
1780    /// # Assumption
1781    ///
1782    /// Assumes that nested chains of AND/OR applications are flattened (this
1783    /// can be enforced with [`Self::flatten_associative`]).
1784    ///
1785    /// # Examples
1786    ///
1787    /// Absorb-OR:
1788    /// ```text
1789    /// a || (a && c) || (a && d)
1790    /// -->
1791    /// a && (true || c || d)
1792    /// -->
1793    /// a && true
1794    /// -->
1795    /// a
1796    /// ```
1797    /// Here only the first step is performed by this method. The rest is done
1798    /// by [`Self::reduce_and_canonicalize_and_or`] called after us in
1799    /// `reduce()`.
1800    ///
1801    /// [distributivity]: https://en.wikipedia.org/wiki/Distributive_property
1802    /// [absorption]: https://en.wikipedia.org/wiki/Absorption_law
1803    fn undistribute_and_or(&mut self) {
1804        // It wouldn't be strictly necessary to wrap this fn in this loop, because `reduce()` calls
1805        // us in a loop anyway. However, `reduce()` tries to do many other things, so the loop here
1806        // improves performance when there are several undistributions to apply in sequence, which
1807        // can occur in `CanonicalizeMfp` when undoing the DNF.
1808        let mut old_self = MirScalarExpr::column(0);
1809        while old_self != *self {
1810            old_self = self.clone();
1811            self.reduce_and_canonicalize_and_or(); // We don't want to deal with 1-arg AND/OR at the top
1812            if let MirScalarExpr::CallVariadic {
1813                exprs: outer_operands,
1814                func: outer_func @ (VariadicFunc::Or | VariadicFunc::And),
1815            } = self
1816            {
1817                let inner_func = outer_func.switch_and_or();
1818
1819                // Make sure that each outer operand is a call to inner_func, by wrapping in a 1-arg
1820                // call if necessary.
1821                outer_operands.iter_mut().for_each(|o| {
1822                    if !matches!(o, MirScalarExpr::CallVariadic {func: f, ..} if *f == inner_func) {
1823                        *o = MirScalarExpr::CallVariadic {
1824                            func: inner_func.clone(),
1825                            exprs: vec![o.take()],
1826                        };
1827                    }
1828                });
1829
1830                let mut inner_operands_refs: Vec<&mut Vec<MirScalarExpr>> = outer_operands
1831                    .iter_mut()
1832                    .map(|o| match o {
1833                        MirScalarExpr::CallVariadic { func: f, exprs } if *f == inner_func => exprs,
1834                        _ => unreachable!(), // the wrapping made sure that we'll get a match
1835                    })
1836                    .collect();
1837
1838                // Find inner operands to undistribute, i.e., which are in _all_ of the outer operands.
1839                let mut intersection = inner_operands_refs
1840                    .iter()
1841                    .map(|v| (*v).clone())
1842                    .reduce(|ops1, ops2| ops1.into_iter().filter(|e| ops2.contains(e)).collect())
1843                    .unwrap();
1844                intersection.sort();
1845                intersection.dedup();
1846
1847                if !intersection.is_empty() {
1848                    // Factor out the intersection from all the top-level args.
1849
1850                    // Remove the intersection from each inner operand vector.
1851                    inner_operands_refs
1852                        .iter_mut()
1853                        .for_each(|ops| (**ops).retain(|o| !intersection.contains(o)));
1854
1855                    // Simplify terms that now have only 0 or 1 args due to removing the intersection.
1856                    outer_operands
1857                        .iter_mut()
1858                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1859
1860                    // Add the intersection at the beginning
1861                    *self = MirScalarExpr::CallVariadic {
1862                        func: inner_func,
1863                        exprs: intersection.into_iter().chain_one(self.clone()).collect(),
1864                    };
1865                } else {
1866                    // If the intersection was empty, that means that there is nothing we can factor out
1867                    // from _all_ the top-level args. However, we might still find something to factor
1868                    // out from a subset of the top-level args. To find such an opportunity, we look for
1869                    // duplicates across all inner args, e.g. if we have
1870                    // `(...) OR (... AND `a` AND ...) OR (...) OR (... AND `a` AND ...)`
1871                    // then we'll find that `a` occurs in more than one top-level arg, so
1872                    // `indexes_to_undistribute` will point us to the 2. and 4. top-level args.
1873
1874                    // Create (inner_operand, index) pairs, where the index is the position in
1875                    // outer_operands
1876                    let all_inner_operands = inner_operands_refs
1877                        .iter()
1878                        .enumerate()
1879                        .flat_map(|(i, inner_vec)| inner_vec.iter().map(move |a| ((*a).clone(), i)))
1880                        .sorted()
1881                        .collect_vec();
1882
1883                    // Find inner operand expressions that occur in more than one top-level arg.
1884                    // Each inner vector in `undistribution_opportunities` will belong to one such inner
1885                    // operand expression, and it is a set of indexes pointing to top-level args where
1886                    // that inner operand occurs.
1887                    let undistribution_opportunities = all_inner_operands
1888                        .iter()
1889                        .chunk_by(|(a, _i)| a)
1890                        .into_iter()
1891                        .map(|(_a, g)| g.map(|(_a, i)| *i).sorted().dedup().collect_vec())
1892                        .filter(|g| g.len() > 1)
1893                        .collect_vec();
1894
1895                    // Choose one of the inner vectors from `undistribution_opportunities`.
1896                    let indexes_to_undistribute = undistribution_opportunities
1897                        .iter()
1898                        // Let's prefer index sets that directly lead to an absorption.
1899                        .find(|index_set| {
1900                            index_set
1901                                .iter()
1902                                .any(|i| inner_operands_refs.get(*i).unwrap().len() == 1)
1903                        })
1904                        // If we didn't find any absorption, then any index set will do.
1905                        .or_else(|| undistribution_opportunities.first())
1906                        .cloned();
1907
1908                    // In any case, undo the 1-arg wrapping that we did at the beginning.
1909                    outer_operands
1910                        .iter_mut()
1911                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1912
1913                    if let Some(indexes_to_undistribute) = indexes_to_undistribute {
1914                        // Found something to undistribute from a subset of the outer operands.
1915                        // We temporarily remove these from outer_operands, call ourselves on it, and
1916                        // then push back the result.
1917                        let mut undistribute_from = MirScalarExpr::CallVariadic {
1918                            func: outer_func.clone(),
1919                            exprs: swap_remove_multiple(outer_operands, indexes_to_undistribute),
1920                        };
1921                        // By construction, the recursive call is guaranteed to hit
1922                        // the `!intersection.is_empty()` branch.
1923                        undistribute_from.undistribute_and_or();
1924                        // Append the undistributed result to outer operands that were not included in
1925                        // indexes_to_undistribute.
1926                        outer_operands.push(undistribute_from);
1927                    }
1928                }
1929            }
1930        }
1931    }
1932
1933    /* #endregion */
1934
1935    /// Adds any columns that *must* be non-Null for `self` to be non-Null.
1936    pub fn non_null_requirements(&self, columns: &mut BTreeSet<usize>) {
1937        match self {
1938            MirScalarExpr::Column(col, _name) => {
1939                columns.insert(*col);
1940            }
1941            MirScalarExpr::Literal(..) => {}
1942            MirScalarExpr::CallUnmaterializable(_) => (),
1943            MirScalarExpr::CallUnary { func, expr } => {
1944                if func.propagates_nulls() {
1945                    expr.non_null_requirements(columns);
1946                }
1947            }
1948            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1949                if func.propagates_nulls() {
1950                    expr1.non_null_requirements(columns);
1951                    expr2.non_null_requirements(columns);
1952                }
1953            }
1954            MirScalarExpr::CallVariadic { func, exprs } => {
1955                if func.propagates_nulls() {
1956                    for expr in exprs {
1957                        expr.non_null_requirements(columns);
1958                    }
1959                }
1960            }
1961            MirScalarExpr::If { .. } => (),
1962        }
1963    }
1964
1965    pub fn typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
1966        match self {
1967            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
1968            MirScalarExpr::Literal(_, typ) => typ.clone(),
1969            MirScalarExpr::CallUnmaterializable(func) => func.output_type(),
1970            MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)),
1971            MirScalarExpr::CallBinary { expr1, expr2, func } => {
1972                func.output_type(expr1.typ(column_types), expr2.typ(column_types))
1973            }
1974            MirScalarExpr::CallVariadic { exprs, func } => {
1975                func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect())
1976            }
1977            MirScalarExpr::If { cond: _, then, els } => {
1978                let then_type = then.typ(column_types);
1979                let else_type = els.typ(column_types);
1980                then_type.union(&else_type)
1981            }
1982        }
1983    }
1984
1985    pub fn repr_typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
1986        match self {
1987            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
1988            MirScalarExpr::Literal(_, typ) => ReprColumnType::from(typ),
1989            MirScalarExpr::CallUnmaterializable(func) => ReprColumnType::from(&func.output_type()),
1990            MirScalarExpr::CallUnary { expr, func } => ReprColumnType::from(
1991                &func.output_type(SqlColumnType::from_repr(&expr.repr_typ(column_types))),
1992            ),
1993            MirScalarExpr::CallBinary { expr1, expr2, func } => {
1994                ReprColumnType::from(&func.output_type(
1995                    SqlColumnType::from_repr(&expr1.repr_typ(column_types)),
1996                    SqlColumnType::from_repr(&expr2.repr_typ(column_types)),
1997                ))
1998            }
1999            MirScalarExpr::CallVariadic { exprs, func } => ReprColumnType::from(
2000                &func.output_type(
2001                    exprs
2002                        .iter()
2003                        .map(|e| SqlColumnType::from_repr(&e.repr_typ(column_types)))
2004                        .collect(),
2005                ),
2006            ),
2007            MirScalarExpr::If { cond: _, then, els } => {
2008                let then_type = then.repr_typ(column_types);
2009                let else_type = els.repr_typ(column_types);
2010                then_type.union(&else_type).unwrap()
2011            }
2012        }
2013    }
2014
2015    pub fn eval<'a>(
2016        &'a self,
2017        datums: &[Datum<'a>],
2018        temp_storage: &'a RowArena,
2019    ) -> Result<Datum<'a>, EvalError> {
2020        match self {
2021            MirScalarExpr::Column(index, _name) => Ok(datums[*index]),
2022            MirScalarExpr::Literal(res, _column_type) => match res {
2023                Ok(row) => Ok(row.unpack_first()),
2024                Err(e) => Err(e.clone()),
2025            },
2026            // Unmaterializable functions must be transformed away before
2027            // evaluation. Their purpose is as a placeholder for data that is
2028            // not known at plan time but can be inlined before runtime.
2029            MirScalarExpr::CallUnmaterializable(x) => Err(EvalError::Internal(
2030                format!("cannot evaluate unmaterializable function: {:?}", x).into(),
2031            )),
2032            MirScalarExpr::CallUnary { func, expr } => func.eval(datums, temp_storage, expr),
2033            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2034                func.eval(datums, temp_storage, expr1, expr2)
2035            }
2036            MirScalarExpr::CallVariadic { func, exprs } => func.eval(datums, temp_storage, exprs),
2037            MirScalarExpr::If { cond, then, els } => match cond.eval(datums, temp_storage)? {
2038                Datum::True => then.eval(datums, temp_storage),
2039                Datum::False | Datum::Null => els.eval(datums, temp_storage),
2040                d => Err(EvalError::Internal(
2041                    format!("if condition evaluated to non-boolean datum: {:?}", d).into(),
2042                )),
2043            },
2044        }
2045    }
2046
2047    /// True iff the expression contains
2048    /// `UnmaterializableFunc::MzNow`.
2049    pub fn contains_temporal(&self) -> bool {
2050        let mut contains = false;
2051        self.visit_pre(|e| {
2052            if let MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2053                contains = true;
2054            }
2055        });
2056        contains
2057    }
2058
2059    /// True iff the expression contains an `UnmaterializableFunc`.
2060    pub fn contains_unmaterializable(&self) -> bool {
2061        let mut contains = false;
2062        self.visit_pre(|e| {
2063            if let MirScalarExpr::CallUnmaterializable(_) = e {
2064                contains = true;
2065            }
2066        });
2067        contains
2068    }
2069
2070    /// True iff the expression contains an `UnmaterializableFunc` that is not in the `exceptions`
2071    /// list.
2072    pub fn contains_unmaterializable_except(&self, exceptions: &[UnmaterializableFunc]) -> bool {
2073        let mut contains = false;
2074        self.visit_pre(|e| match e {
2075            MirScalarExpr::CallUnmaterializable(f) if !exceptions.contains(f) => contains = true,
2076            _ => (),
2077        });
2078        contains
2079    }
2080
2081    /// True iff the expression contains a `Column`.
2082    pub fn contains_column(&self) -> bool {
2083        let mut contains = false;
2084        self.visit_pre(|e| {
2085            if let MirScalarExpr::Column(_col, _name) = e {
2086                contains = true;
2087            }
2088        });
2089        contains
2090    }
2091
2092    /// True iff the expression contains a `Dummy`.
2093    pub fn contains_dummy(&self) -> bool {
2094        let mut contains = false;
2095        self.visit_pre(|e| {
2096            if let MirScalarExpr::Literal(row, _) = e {
2097                if let Ok(row) = row {
2098                    contains |= row.iter().any(|d| d.contains_dummy());
2099                }
2100            }
2101        });
2102        contains
2103    }
2104
2105    /// The size of the expression as a tree.
2106    pub fn size(&self) -> usize {
2107        let mut size = 0;
2108        self.visit_pre(&mut |_: &MirScalarExpr| {
2109            size += 1;
2110        });
2111        size
2112    }
2113}
2114
2115impl MirScalarExpr {
2116    /// True iff evaluation could possibly error on non-error input `Datum`.
2117    pub fn could_error(&self) -> bool {
2118        match self {
2119            MirScalarExpr::Column(_col, _name) => false,
2120            MirScalarExpr::Literal(row, ..) => row.is_err(),
2121            MirScalarExpr::CallUnmaterializable(_) => true,
2122            MirScalarExpr::CallUnary { func, expr } => func.could_error() || expr.could_error(),
2123            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2124                func.could_error() || expr1.could_error() || expr2.could_error()
2125            }
2126            MirScalarExpr::CallVariadic { func, exprs } => {
2127                func.could_error() || exprs.iter().any(|e| e.could_error())
2128            }
2129            MirScalarExpr::If { cond, then, els } => {
2130                cond.could_error() || then.could_error() || els.could_error()
2131            }
2132        }
2133    }
2134}
2135
2136impl VisitChildren<Self> for MirScalarExpr {
2137    fn visit_children<F>(&self, mut f: F)
2138    where
2139        F: FnMut(&Self),
2140    {
2141        use MirScalarExpr::*;
2142        match self {
2143            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2144            CallUnary { expr, .. } => {
2145                f(expr);
2146            }
2147            CallBinary { expr1, expr2, .. } => {
2148                f(expr1);
2149                f(expr2);
2150            }
2151            CallVariadic { exprs, .. } => {
2152                for expr in exprs {
2153                    f(expr);
2154                }
2155            }
2156            If { cond, then, els } => {
2157                f(cond);
2158                f(then);
2159                f(els);
2160            }
2161        }
2162    }
2163
2164    fn visit_mut_children<F>(&mut self, mut f: F)
2165    where
2166        F: FnMut(&mut Self),
2167    {
2168        use MirScalarExpr::*;
2169        match self {
2170            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2171            CallUnary { expr, .. } => {
2172                f(expr);
2173            }
2174            CallBinary { expr1, expr2, .. } => {
2175                f(expr1);
2176                f(expr2);
2177            }
2178            CallVariadic { exprs, .. } => {
2179                for expr in exprs {
2180                    f(expr);
2181                }
2182            }
2183            If { cond, then, els } => {
2184                f(cond);
2185                f(then);
2186                f(els);
2187            }
2188        }
2189    }
2190
2191    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2192    where
2193        F: FnMut(&Self) -> Result<(), E>,
2194        E: From<RecursionLimitError>,
2195    {
2196        use MirScalarExpr::*;
2197        match self {
2198            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2199            CallUnary { expr, .. } => {
2200                f(expr)?;
2201            }
2202            CallBinary { expr1, expr2, .. } => {
2203                f(expr1)?;
2204                f(expr2)?;
2205            }
2206            CallVariadic { exprs, .. } => {
2207                for expr in exprs {
2208                    f(expr)?;
2209                }
2210            }
2211            If { cond, then, els } => {
2212                f(cond)?;
2213                f(then)?;
2214                f(els)?;
2215            }
2216        }
2217        Ok(())
2218    }
2219
2220    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2221    where
2222        F: FnMut(&mut Self) -> Result<(), E>,
2223        E: From<RecursionLimitError>,
2224    {
2225        use MirScalarExpr::*;
2226        match self {
2227            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2228            CallUnary { expr, .. } => {
2229                f(expr)?;
2230            }
2231            CallBinary { expr1, expr2, .. } => {
2232                f(expr1)?;
2233                f(expr2)?;
2234            }
2235            CallVariadic { exprs, .. } => {
2236                for expr in exprs {
2237                    f(expr)?;
2238                }
2239            }
2240            If { cond, then, els } => {
2241                f(cond)?;
2242                f(then)?;
2243                f(els)?;
2244            }
2245        }
2246        Ok(())
2247    }
2248}
2249
2250impl MirScalarExpr {
2251    /// Iterates through references to child expressions.
2252    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2253        let mut first = None;
2254        let mut second = None;
2255        let mut third = None;
2256        let mut variadic = None;
2257
2258        use MirScalarExpr::*;
2259        match self {
2260            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2261            CallUnary { expr, .. } => {
2262                first = Some(&**expr);
2263            }
2264            CallBinary { expr1, expr2, .. } => {
2265                first = Some(&**expr1);
2266                second = Some(&**expr2);
2267            }
2268            CallVariadic { exprs, .. } => {
2269                variadic = Some(exprs);
2270            }
2271            If { cond, then, els } => {
2272                first = Some(&**cond);
2273                second = Some(&**then);
2274                third = Some(&**els);
2275            }
2276        }
2277
2278        first
2279            .into_iter()
2280            .chain(second)
2281            .chain(third)
2282            .chain(variadic.into_iter().flatten())
2283    }
2284
2285    /// Iterates through mutable references to child expressions.
2286    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2287        let mut first = None;
2288        let mut second = None;
2289        let mut third = None;
2290        let mut variadic = None;
2291
2292        use MirScalarExpr::*;
2293        match self {
2294            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2295            CallUnary { expr, .. } => {
2296                first = Some(&mut **expr);
2297            }
2298            CallBinary { expr1, expr2, .. } => {
2299                first = Some(&mut **expr1);
2300                second = Some(&mut **expr2);
2301            }
2302            CallVariadic { exprs, .. } => {
2303                variadic = Some(exprs);
2304            }
2305            If { cond, then, els } => {
2306                first = Some(&mut **cond);
2307                second = Some(&mut **then);
2308                third = Some(&mut **els);
2309            }
2310        }
2311
2312        first
2313            .into_iter()
2314            .chain(second)
2315            .chain(third)
2316            .chain(variadic.into_iter().flatten())
2317    }
2318
2319    /// Visits all subexpressions in DFS preorder.
2320    pub fn visit_pre<F>(&self, mut f: F)
2321    where
2322        F: FnMut(&Self),
2323    {
2324        let mut worklist = vec![self];
2325        while let Some(e) = worklist.pop() {
2326            f(e);
2327            worklist.extend(e.children().rev());
2328        }
2329    }
2330
2331    /// Iterative pre-order visitor.
2332    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2333        let mut worklist = vec![self];
2334        while let Some(expr) = worklist.pop() {
2335            f(expr);
2336            worklist.extend(expr.children_mut().rev());
2337        }
2338    }
2339}
2340
/// Filter characteristics that are used for ordering join inputs.
/// This can be created for a `Vec<MirScalarExpr>`, which represents an AND of predicates.
///
/// The fields are ordered based on heuristic assumptions about their typical selectivity, so that
/// `Ord` gives the right ordering for join inputs. Bigger is better, i.e., will tend to come
/// earlier than other inputs.
///
/// The derived `Ord` compares fields in declaration order, so reordering the fields changes the
/// resulting ordering.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub struct FilterCharacteristics {
    /// `<expr> = <literal>` appears in the filter.
    /// Excludes cases where NOT appears anywhere above the literal equality.
    literal_equality: bool,
    /// A LIKE match appears in the filter, outside of any NOT.
    /// (Assuming a random string of lower-case characters, `LIKE 'a%'` has a selectivity of 1/26.)
    like: bool,
    /// An `IS NULL` test appears in the filter, outside of any NOT.
    is_null: bool,
    /// Number of Vec elements that involve inequality predicates. (A BETWEEN is represented as two
    /// inequality predicates.)
    /// Excludes cases where NOT appears around the literal inequality.
    /// Note that for inequality predicates, some databases assume 1/3 selectivity in the absence of
    /// concrete statistics.
    literal_inequality: usize,
    /// Any filter, except ones involving `IS NOT NULL`, because those are too common.
    /// Can be true by itself, or any other field being true can also make this true.
    /// `NOT LIKE` is only in this category.
    /// `!=` is only in this category.
    /// `NOT (a = b)` is turned into `!=` by `reduce` before us!
    any_filter: bool,
}
2368
2369impl BitOrAssign for FilterCharacteristics {
2370    fn bitor_assign(&mut self, rhs: Self) {
2371        self.literal_equality |= rhs.literal_equality;
2372        self.like |= rhs.like;
2373        self.is_null |= rhs.is_null;
2374        self.literal_inequality += rhs.literal_inequality;
2375        self.any_filter |= rhs.any_filter;
2376    }
2377}
2378
2379impl FilterCharacteristics {
2380    pub fn none() -> FilterCharacteristics {
2381        FilterCharacteristics {
2382            literal_equality: false,
2383            like: false,
2384            is_null: false,
2385            literal_inequality: 0,
2386            any_filter: false,
2387        }
2388    }
2389
2390    pub fn explain(&self) -> String {
2391        let mut e = "".to_owned();
2392        if self.literal_equality {
2393            e.push_str("e");
2394        }
2395        if self.like {
2396            e.push_str("l");
2397        }
2398        if self.is_null {
2399            e.push_str("n");
2400        }
2401        for _ in 0..self.literal_inequality {
2402            e.push_str("i");
2403        }
2404        if self.any_filter {
2405            e.push_str("f");
2406        }
2407        e
2408    }
2409
2410    pub fn filter_characteristics(
2411        filters: &Vec<MirScalarExpr>,
2412    ) -> Result<FilterCharacteristics, RecursionLimitError> {
2413        let mut literal_equality = false;
2414        let mut like = false;
2415        let mut is_null = false;
2416        let mut literal_inequality = 0;
2417        let mut any_filter = false;
2418        filters.iter().try_for_each(|f| {
2419            let mut literal_inequality_in_current_filter = false;
2420            let mut is_not_null_in_current_filter = false;
2421            f.visit_pre_with_context(
2422                false,
2423                &mut |not_in_parent_chain, expr| {
2424                    not_in_parent_chain
2425                        || matches!(
2426                            expr,
2427                            MirScalarExpr::CallUnary {
2428                                func: UnaryFunc::Not(func::Not),
2429                                ..
2430                            }
2431                        )
2432                },
2433                &mut |not_in_parent_chain, expr| {
2434                    if !not_in_parent_chain {
2435                        if expr.any_expr_eq_literal().is_some() {
2436                            literal_equality = true;
2437                        }
2438                        if expr.any_expr_ineq_literal() {
2439                            literal_inequality_in_current_filter = true;
2440                        }
2441                        if matches!(
2442                            expr,
2443                            MirScalarExpr::CallUnary {
2444                                func: UnaryFunc::IsLikeMatch(_),
2445                                ..
2446                            }
2447                        ) {
2448                            like = true;
2449                        }
2450                    };
2451                    if matches!(
2452                        expr,
2453                        MirScalarExpr::CallUnary {
2454                            func: UnaryFunc::IsNull(crate::func::IsNull),
2455                            ..
2456                        }
2457                    ) {
2458                        if *not_in_parent_chain {
2459                            is_not_null_in_current_filter = true;
2460                        } else {
2461                            is_null = true;
2462                        }
2463                    }
2464                },
2465            )?;
2466            if literal_inequality_in_current_filter {
2467                literal_inequality += 1;
2468            }
2469            if !is_not_null_in_current_filter {
2470                // We want to ignore `IS NOT NULL` for `any_filter`.
2471                any_filter = true;
2472            }
2473            Ok(())
2474        })?;
2475        Ok(FilterCharacteristics {
2476            literal_equality,
2477            like,
2478            is_null,
2479            literal_inequality,
2480            any_filter,
2481        })
2482    }
2483
2484    pub fn add_literal_equality(&mut self) {
2485        self.literal_equality = true;
2486    }
2487
2488    pub fn worst_case_scaling_factor(&self) -> f64 {
2489        let mut factor = 1.0;
2490
2491        if self.literal_equality {
2492            factor *= 0.1;
2493        }
2494
2495        if self.is_null {
2496            factor *= 0.1;
2497        }
2498
2499        if self.literal_inequality >= 2 {
2500            factor *= 0.25;
2501        } else if self.literal_inequality == 1 {
2502            factor *= 0.33;
2503        }
2504
2505        // catch various negated filters, treat them pessimistically
2506        if !(self.literal_equality || self.is_null || self.literal_inequality > 0)
2507            && self.any_filter
2508        {
2509            factor *= 0.9;
2510        }
2511
2512        factor
2513    }
2514}
2515
/// A limit value carried by [`EvalError::OutOfDomain`], which holds two of these.
///
/// NOTE(review): the two limits in `OutOfDomain` are presumably the lower and upper end of a
/// function's permitted input range, in that order — confirm against its `Display` output.
#[derive(
    Arbitrary,
    Ord,
    PartialOrd,
    Copy,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect,
)]
pub enum DomainLimit {
    /// No limit value (presumably: unbounded on this side — TODO confirm).
    None,
    /// A limit at the contained value (presumably the value itself is permitted).
    Inclusive(i64),
    /// A limit at the contained value (presumably the value itself is not permitted).
    Exclusive(i64),
}
2535
2536impl RustType<ProtoDomainLimit> for DomainLimit {
2537    fn into_proto(&self) -> ProtoDomainLimit {
2538        use proto_domain_limit::Kind::*;
2539        let kind = match self {
2540            DomainLimit::None => None(()),
2541            DomainLimit::Inclusive(v) => Inclusive(*v),
2542            DomainLimit::Exclusive(v) => Exclusive(*v),
2543        };
2544        ProtoDomainLimit { kind: Some(kind) }
2545    }
2546
2547    fn from_proto(proto: ProtoDomainLimit) -> Result<Self, TryFromProtoError> {
2548        use proto_domain_limit::Kind::*;
2549        if let Some(kind) = proto.kind {
2550            match kind {
2551                None(()) => Ok(DomainLimit::None),
2552                Inclusive(v) => Ok(DomainLimit::Inclusive(v)),
2553                Exclusive(v) => Ok(DomainLimit::Exclusive(v)),
2554            }
2555        } else {
2556            Err(TryFromProtoError::missing_field("ProtoDomainLimit::kind"))
2557        }
2558    }
2559}
2560
/// An error produced while evaluating a scalar expression.
///
/// `MirScalarExpr::eval` returns this type on failure, and a `MirScalarExpr::Literal` can
/// store one in place of a row. The user-facing message of each variant is defined by the
/// `fmt::Display` implementation.
#[derive(
    Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect,
)]
pub enum EvalError {
    /// Displayed as "requested character not valid for encoding: {0}".
    CharacterNotValidForEncoding(i32),
    /// Displayed as "requested character too large for encoding: {0}".
    CharacterTooLargeForEncoding(i32),
    /// Carries the complete message, which is displayed verbatim.
    DateBinOutOfRange(Box<str>),
    DivisionByZero,
    /// Displayed as "{feature} not yet supported"; when `discussion_no` is present, a link to
    /// that GitHub discussion is appended.
    Unsupported {
        feature: Box<str>,
        discussion_no: Option<usize>,
    },
    FloatOverflow,
    FloatUnderflow,
    NumericFieldOverflow,
    // The `*OutOfRange(Box<str>)` variants below carry a string that is displayed quoted,
    // followed by the name of the type and "out of range".
    Float32OutOfRange(Box<str>),
    Float64OutOfRange(Box<str>),
    Int16OutOfRange(Box<str>),
    Int32OutOfRange(Box<str>),
    Int64OutOfRange(Box<str>),
    UInt16OutOfRange(Box<str>),
    UInt32OutOfRange(Box<str>),
    UInt64OutOfRange(Box<str>),
    MzTimestampOutOfRange(Box<str>),
    MzTimestampStepOverflow,
    OidOutOfRange(Box<str>),
    IntervalOutOfRange(Box<str>),
    TimestampCannotBeNan,
    TimestampOutOfRange,
    DateOutOfRange,
    CharOutOfRange,
    /// Displayed as "index {provided} out of valid range, 0..{valid_end}".
    IndexOutOfRange {
        provided: i32,
        // The last valid index position, i.e. `v.len() - 1`
        valid_end: i32,
    },
    InvalidBase64Equals,
    InvalidBase64Symbol(char),
    InvalidBase64EndSequence,
    InvalidTimezone(Box<str>),
    InvalidTimezoneInterval,
    InvalidTimezoneConversion,
    InvalidIanaTimezoneId(Box<str>),
    InvalidLayer {
        max_layer: usize,
        val: i64,
    },
    InvalidArray(InvalidArrayError),
    InvalidEncodingName(Box<str>),
    InvalidHashAlgorithm(Box<str>),
    InvalidByteSequence {
        byte_sequence: Box<str>,
        encoding_name: Box<str>,
    },
    InvalidJsonbCast {
        from: Box<str>,
        to: Box<str>,
    },
    InvalidRegex(Box<str>),
    InvalidRegexFlag(char),
    InvalidParameterValue(Box<str>),
    InvalidDatePart(Box<str>),
    KeyCannotBeNull,
    NegSqrt,
    NegLimit,
    NullCharacterNotPermitted,
    UnknownUnits(Box<str>),
    UnsupportedUnits(Box<str>, Box<str>),
    UnterminatedLikeEscapeSequence,
    Parse(ParseError),
    ParseHex(ParseHexError),
    /// An internal error with a free-form message. `MirScalarExpr::eval` returns it when
    /// asked to evaluate an unmaterializable function, or when an `If` condition evaluates
    /// to a non-boolean datum.
    Internal(Box<str>),
    InfinityOutOfDomain(Box<str>),
    NegativeOutOfDomain(Box<str>),
    ZeroOutOfDomain(Box<str>),
    OutOfDomain(DomainLimit, DomainLimit, Box<str>),
    ComplexOutOfRange(Box<str>),
    MultipleRowsFromSubquery,
    Undefined(Box<str>),
    LikePatternTooLong,
    LikeEscapeTooLong,
    StringValueTooLong {
        target_type: Box<str>,
        length: usize,
    },
    MultidimensionalArrayRemovalNotSupported,
    IncompatibleArrayDimensions {
        dims: Option<(usize, usize)>,
    },
    TypeFromOid(Box<str>),
    InvalidRange(InvalidRangeError),
    InvalidRoleId(Box<str>),
    InvalidPrivileges(Box<str>),
    LetRecLimitExceeded(Box<str>),
    MultiDimensionalArraySearch,
    MustNotBeNull(Box<str>),
    InvalidIdentifier {
        ident: Box<str>,
        detail: Option<Box<str>>,
    },
    ArrayFillWrongArraySubscripts,
    // TODO: propagate this check more widely throughout the expr crate
    MaxArraySizeExceeded(usize),
    DateDiffOverflow {
        unit: Box<str>,
        a: Box<str>,
        b: Box<str>,
    },
    // The error for ErrorIfNull; this should not be used in other contexts as a generic error
    // printer.
    IfNullError(Box<str>),
    LengthTooLarge,
    AclArrayNullElement,
    MzAclArrayNullElement,
    PrettyError(Box<str>),
}
2677
2678impl fmt::Display for EvalError {
2679    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2680        match self {
2681            EvalError::CharacterNotValidForEncoding(v) => {
2682                write!(f, "requested character not valid for encoding: {v}")
2683            }
2684            EvalError::CharacterTooLargeForEncoding(v) => {
2685                write!(f, "requested character too large for encoding: {v}")
2686            }
2687            EvalError::DateBinOutOfRange(message) => f.write_str(message),
2688            EvalError::DivisionByZero => f.write_str("division by zero"),
2689            EvalError::Unsupported {
2690                feature,
2691                discussion_no,
2692            } => {
2693                write!(f, "{} not yet supported", feature)?;
2694                if let Some(discussion_no) = discussion_no {
2695                    write!(
2696                        f,
2697                        ", see https://github.com/MaterializeInc/materialize/discussions/{} for more details",
2698                        discussion_no
2699                    )?;
2700                }
2701                Ok(())
2702            }
2703            EvalError::FloatOverflow => f.write_str("value out of range: overflow"),
2704            EvalError::FloatUnderflow => f.write_str("value out of range: underflow"),
2705            EvalError::NumericFieldOverflow => f.write_str("numeric field overflow"),
2706            EvalError::Float32OutOfRange(val) => write!(f, "{} real out of range", val.quoted()),
2707            EvalError::Float64OutOfRange(val) => {
2708                write!(f, "{} double precision out of range", val.quoted())
2709            }
2710            EvalError::Int16OutOfRange(val) => write!(f, "{} smallint out of range", val.quoted()),
2711            EvalError::Int32OutOfRange(val) => write!(f, "{} integer out of range", val.quoted()),
2712            EvalError::Int64OutOfRange(val) => write!(f, "{} bigint out of range", val.quoted()),
2713            EvalError::UInt16OutOfRange(val) => write!(f, "{} uint2 out of range", val.quoted()),
2714            EvalError::UInt32OutOfRange(val) => write!(f, "{} uint4 out of range", val.quoted()),
2715            EvalError::UInt64OutOfRange(val) => write!(f, "{} uint8 out of range", val.quoted()),
2716            EvalError::MzTimestampOutOfRange(val) => {
2717                write!(f, "{} mz_timestamp out of range", val.quoted())
2718            }
2719            EvalError::MzTimestampStepOverflow => f.write_str("step mz_timestamp overflow"),
2720            EvalError::OidOutOfRange(val) => write!(f, "{} OID out of range", val.quoted()),
2721            EvalError::IntervalOutOfRange(val) => {
2722                write!(f, "{} interval out of range", val.quoted())
2723            }
2724            EvalError::TimestampCannotBeNan => f.write_str("timestamp cannot be NaN"),
2725            EvalError::TimestampOutOfRange => f.write_str("timestamp out of range"),
2726            EvalError::DateOutOfRange => f.write_str("date out of range"),
2727            EvalError::CharOutOfRange => f.write_str("\"char\" out of range"),
2728            EvalError::IndexOutOfRange {
2729                provided,
2730                valid_end,
2731            } => write!(f, "index {provided} out of valid range, 0..{valid_end}",),
2732            EvalError::InvalidBase64Equals => {
2733                f.write_str("unexpected \"=\" while decoding base64 sequence")
2734            }
2735            EvalError::InvalidBase64Symbol(c) => write!(
2736                f,
2737                "invalid symbol \"{}\" found while decoding base64 sequence",
2738                c.escape_default()
2739            ),
2740            EvalError::InvalidBase64EndSequence => f.write_str("invalid base64 end sequence"),
2741            EvalError::InvalidJsonbCast { from, to } => {
2742                write!(f, "cannot cast jsonb {} to type {}", from, to)
2743            }
2744            EvalError::InvalidTimezone(tz) => write!(f, "invalid time zone '{}'", tz),
2745            EvalError::InvalidTimezoneInterval => {
2746                f.write_str("timezone interval must not contain months or years")
2747            }
2748            EvalError::InvalidTimezoneConversion => f.write_str("invalid timezone conversion"),
2749            EvalError::InvalidIanaTimezoneId(tz) => {
2750                write!(f, "invalid IANA Time Zone Database identifier: '{}'", tz)
2751            }
2752            EvalError::InvalidLayer { max_layer, val } => write!(
2753                f,
2754                "invalid layer: {}; must use value within [1, {}]",
2755                val, max_layer
2756            ),
2757            EvalError::InvalidArray(e) => e.fmt(f),
2758            EvalError::InvalidEncodingName(name) => write!(f, "invalid encoding name '{}'", name),
2759            EvalError::InvalidHashAlgorithm(alg) => write!(f, "invalid hash algorithm '{}'", alg),
2760            EvalError::InvalidByteSequence {
2761                byte_sequence,
2762                encoding_name,
2763            } => write!(
2764                f,
2765                "invalid byte sequence '{}' for encoding '{}'",
2766                byte_sequence, encoding_name
2767            ),
2768            EvalError::InvalidDatePart(part) => write!(f, "invalid datepart {}", part.quoted()),
2769            EvalError::KeyCannotBeNull => f.write_str("key cannot be null"),
2770            EvalError::NegSqrt => f.write_str("cannot take square root of a negative number"),
2771            EvalError::NegLimit => f.write_str("LIMIT must not be negative"),
2772            EvalError::NullCharacterNotPermitted => f.write_str("null character not permitted"),
2773            EvalError::InvalidRegex(e) => write!(f, "invalid regular expression: {}", e),
2774            EvalError::InvalidRegexFlag(c) => write!(f, "invalid regular expression flag: {}", c),
2775            EvalError::InvalidParameterValue(s) => f.write_str(s),
2776            EvalError::UnknownUnits(units) => write!(f, "unit '{}' not recognized", units),
2777            EvalError::UnsupportedUnits(units, typ) => {
2778                write!(f, "unit '{}' not supported for type {}", units, typ)
2779            }
2780            EvalError::UnterminatedLikeEscapeSequence => {
2781                f.write_str("unterminated escape sequence in LIKE")
2782            }
2783            EvalError::Parse(e) => e.fmt(f),
2784            EvalError::PrettyError(e) => e.fmt(f),
2785            EvalError::ParseHex(e) => e.fmt(f),
2786            EvalError::Internal(s) => write!(f, "internal error: {}", s),
2787            EvalError::InfinityOutOfDomain(s) => {
2788                write!(f, "function {} is only defined for finite arguments", s)
2789            }
2790            EvalError::NegativeOutOfDomain(s) => {
2791                write!(f, "function {} is not defined for negative numbers", s)
2792            }
2793            EvalError::ZeroOutOfDomain(s) => {
2794                write!(f, "function {} is not defined for zero", s)
2795            }
2796            EvalError::OutOfDomain(lower, upper, s) => {
2797                use DomainLimit::*;
2798                write!(f, "function {s} is defined for numbers ")?;
2799                match (lower, upper) {
2800                    (Inclusive(n), None) => write!(f, "greater than or equal to {n}"),
2801                    (Exclusive(n), None) => write!(f, "greater than {n}"),
2802                    (None, Inclusive(n)) => write!(f, "less than or equal to {n}"),
2803                    (None, Exclusive(n)) => write!(f, "less than {n}"),
2804                    (Inclusive(lo), Inclusive(hi)) => write!(f, "between {lo} and {hi} inclusive"),
2805                    (Exclusive(lo), Exclusive(hi)) => write!(f, "between {lo} and {hi} exclusive"),
2806                    (Inclusive(lo), Exclusive(hi)) => {
2807                        write!(f, "between {lo} inclusive and {hi} exclusive")
2808                    }
2809                    (Exclusive(lo), Inclusive(hi)) => {
2810                        write!(f, "between {lo} exclusive and {hi} inclusive")
2811                    }
2812                    (None, None) => panic!("invalid domain error"),
2813                }
2814            }
2815            EvalError::ComplexOutOfRange(s) => {
2816                write!(f, "function {} cannot return complex numbers", s)
2817            }
2818            EvalError::MultipleRowsFromSubquery => {
2819                write!(f, "more than one record produced in subquery")
2820            }
2821            EvalError::Undefined(s) => {
2822                write!(f, "{} is undefined", s)
2823            }
2824            EvalError::LikePatternTooLong => {
2825                write!(f, "LIKE pattern exceeds maximum length")
2826            }
2827            EvalError::LikeEscapeTooLong => {
2828                write!(f, "invalid escape string")
2829            }
2830            EvalError::StringValueTooLong {
2831                target_type,
2832                length,
2833            } => {
2834                write!(f, "value too long for type {}({})", target_type, length)
2835            }
2836            EvalError::MultidimensionalArrayRemovalNotSupported => {
2837                write!(
2838                    f,
2839                    "removing elements from multidimensional arrays is not supported"
2840                )
2841            }
2842            EvalError::IncompatibleArrayDimensions { dims: _ } => {
2843                write!(f, "cannot concatenate incompatible arrays")
2844            }
2845            EvalError::TypeFromOid(msg) => write!(f, "{msg}"),
2846            EvalError::InvalidRange(e) => e.fmt(f),
2847            EvalError::InvalidRoleId(msg) => write!(f, "{msg}"),
2848            EvalError::InvalidPrivileges(privilege) => {
2849                write!(f, "unrecognized privilege type: {privilege}")
2850            }
2851            EvalError::LetRecLimitExceeded(max_iters) => {
2852                write!(
2853                    f,
2854                    "Recursive query exceeded the recursion limit {}. (Use RETURN AT RECURSION LIMIT to not error, but return the current state as the final result when reaching the limit.)",
2855                    max_iters
2856                )
2857            }
2858            EvalError::MultiDimensionalArraySearch => write!(
2859                f,
2860                "searching for elements in multidimensional arrays is not supported"
2861            ),
2862            EvalError::MustNotBeNull(v) => write!(f, "{v} must not be null"),
2863            EvalError::InvalidIdentifier { ident, .. } => {
2864                write!(f, "string is not a valid identifier: {}", ident.quoted())
2865            }
2866            EvalError::ArrayFillWrongArraySubscripts => {
2867                f.write_str("wrong number of array subscripts")
2868            }
2869            EvalError::MaxArraySizeExceeded(max_size) => {
2870                write!(
2871                    f,
2872                    "array size exceeds the maximum allowed ({max_size} bytes)"
2873                )
2874            }
2875            EvalError::DateDiffOverflow { unit, a, b } => {
2876                write!(f, "datediff overflow, {unit} of {a}, {b}")
2877            }
2878            EvalError::IfNullError(s) => f.write_str(s),
2879            EvalError::LengthTooLarge => write!(f, "requested length too large"),
2880            EvalError::AclArrayNullElement => write!(f, "ACL arrays must not contain null values"),
2881            EvalError::MzAclArrayNullElement => {
2882                write!(f, "MZ_ACL arrays must not contain null values")
2883            }
2884        }
2885    }
2886}
2887
2888impl EvalError {
2889    pub fn detail(&self) -> Option<String> {
2890        match self {
2891            EvalError::IncompatibleArrayDimensions { dims: None } => Some(
2892                "Arrays with differing dimensions are not compatible for concatenation.".into(),
2893            ),
2894            EvalError::IncompatibleArrayDimensions {
2895                dims: Some((a_dims, b_dims)),
2896            } => Some(format!(
2897                "Arrays of {} and {} dimensions are not compatible for concatenation.",
2898                a_dims, b_dims
2899            )),
2900            EvalError::InvalidIdentifier { detail, .. } => detail.as_deref().map(Into::into),
2901            EvalError::ArrayFillWrongArraySubscripts => {
2902                Some("Low bound array has different size than dimensions array.".into())
2903            }
2904            _ => None,
2905        }
2906    }
2907
2908    pub fn hint(&self) -> Option<String> {
2909        match self {
2910            EvalError::InvalidBase64EndSequence => Some(
2911                "Input data is missing padding, is truncated, or is otherwise corrupted.".into(),
2912            ),
2913            EvalError::LikeEscapeTooLong => {
2914                Some("Escape string must be empty or one character.".into())
2915            }
2916            EvalError::MzTimestampOutOfRange(_) => Some(
2917                "Integer, numeric, and text casts to mz_timestamp must be in the form of whole \
2918                milliseconds since the Unix epoch. Values with fractional parts cannot be \
2919                converted to mz_timestamp."
2920                    .into(),
2921            ),
2922            _ => None,
2923        }
2924    }
2925}
2926
2927impl std::error::Error for EvalError {}
2928
2929impl From<ParseError> for EvalError {
2930    fn from(e: ParseError) -> EvalError {
2931        EvalError::Parse(e)
2932    }
2933}
2934
2935impl From<ParseHexError> for EvalError {
2936    fn from(e: ParseHexError) -> EvalError {
2937        EvalError::ParseHex(e)
2938    }
2939}
2940
2941impl From<InvalidArrayError> for EvalError {
2942    fn from(e: InvalidArrayError) -> EvalError {
2943        EvalError::InvalidArray(e)
2944    }
2945}
2946
2947impl From<RegexCompilationError> for EvalError {
2948    fn from(e: RegexCompilationError) -> EvalError {
2949        EvalError::InvalidRegex(e.to_string().into())
2950    }
2951}
2952
2953impl From<TypeFromOidError> for EvalError {
2954    fn from(e: TypeFromOidError) -> EvalError {
2955        EvalError::TypeFromOid(e.to_string().into())
2956    }
2957}
2958
2959impl From<DateError> for EvalError {
2960    fn from(e: DateError) -> EvalError {
2961        match e {
2962            DateError::OutOfRange => EvalError::DateOutOfRange,
2963        }
2964    }
2965}
2966
2967impl From<TimestampError> for EvalError {
2968    fn from(e: TimestampError) -> EvalError {
2969        match e {
2970            TimestampError::OutOfRange => EvalError::TimestampOutOfRange,
2971        }
2972    }
2973}
2974
2975impl From<InvalidRangeError> for EvalError {
2976    fn from(e: InvalidRangeError) -> EvalError {
2977        EvalError::InvalidRange(e)
2978    }
2979}
2980
2981impl RustType<ProtoEvalError> for EvalError {
2982    fn into_proto(&self) -> ProtoEvalError {
2983        use proto_eval_error::Kind::*;
2984        use proto_eval_error::*;
2985        let kind = match self {
2986            EvalError::CharacterNotValidForEncoding(v) => CharacterNotValidForEncoding(*v),
2987            EvalError::CharacterTooLargeForEncoding(v) => CharacterTooLargeForEncoding(*v),
2988            EvalError::DateBinOutOfRange(v) => DateBinOutOfRange(v.into_proto()),
2989            EvalError::DivisionByZero => DivisionByZero(()),
2990            EvalError::Unsupported {
2991                feature,
2992                discussion_no,
2993            } => Unsupported(ProtoUnsupported {
2994                feature: feature.into_proto(),
2995                discussion_no: discussion_no.into_proto(),
2996            }),
2997            EvalError::FloatOverflow => FloatOverflow(()),
2998            EvalError::FloatUnderflow => FloatUnderflow(()),
2999            EvalError::NumericFieldOverflow => NumericFieldOverflow(()),
3000            EvalError::Float32OutOfRange(val) => Float32OutOfRange(ProtoValueOutOfRange {
3001                value: val.to_string(),
3002            }),
3003            EvalError::Float64OutOfRange(val) => Float64OutOfRange(ProtoValueOutOfRange {
3004                value: val.to_string(),
3005            }),
3006            EvalError::Int16OutOfRange(val) => Int16OutOfRange(ProtoValueOutOfRange {
3007                value: val.to_string(),
3008            }),
3009            EvalError::Int32OutOfRange(val) => Int32OutOfRange(ProtoValueOutOfRange {
3010                value: val.to_string(),
3011            }),
3012            EvalError::Int64OutOfRange(val) => Int64OutOfRange(ProtoValueOutOfRange {
3013                value: val.to_string(),
3014            }),
3015            EvalError::UInt16OutOfRange(val) => Uint16OutOfRange(ProtoValueOutOfRange {
3016                value: val.to_string(),
3017            }),
3018            EvalError::UInt32OutOfRange(val) => Uint32OutOfRange(ProtoValueOutOfRange {
3019                value: val.to_string(),
3020            }),
3021            EvalError::UInt64OutOfRange(val) => Uint64OutOfRange(ProtoValueOutOfRange {
3022                value: val.to_string(),
3023            }),
3024            EvalError::MzTimestampOutOfRange(val) => MzTimestampOutOfRange(ProtoValueOutOfRange {
3025                value: val.to_string(),
3026            }),
3027            EvalError::MzTimestampStepOverflow => MzTimestampStepOverflow(()),
3028            EvalError::OidOutOfRange(val) => OidOutOfRange(ProtoValueOutOfRange {
3029                value: val.to_string(),
3030            }),
3031            EvalError::IntervalOutOfRange(val) => IntervalOutOfRange(ProtoValueOutOfRange {
3032                value: val.to_string(),
3033            }),
3034            EvalError::TimestampCannotBeNan => TimestampCannotBeNan(()),
3035            EvalError::TimestampOutOfRange => TimestampOutOfRange(()),
3036            EvalError::DateOutOfRange => DateOutOfRange(()),
3037            EvalError::CharOutOfRange => CharOutOfRange(()),
3038            EvalError::IndexOutOfRange {
3039                provided,
3040                valid_end,
3041            } => IndexOutOfRange(ProtoIndexOutOfRange {
3042                provided: *provided,
3043                valid_end: *valid_end,
3044            }),
3045            EvalError::InvalidBase64Equals => InvalidBase64Equals(()),
3046            EvalError::InvalidBase64Symbol(sym) => InvalidBase64Symbol(sym.into_proto()),
3047            EvalError::InvalidBase64EndSequence => InvalidBase64EndSequence(()),
3048            EvalError::InvalidTimezone(tz) => InvalidTimezone(tz.into_proto()),
3049            EvalError::InvalidTimezoneInterval => InvalidTimezoneInterval(()),
3050            EvalError::InvalidTimezoneConversion => InvalidTimezoneConversion(()),
3051            EvalError::InvalidLayer { max_layer, val } => InvalidLayer(ProtoInvalidLayer {
3052                max_layer: max_layer.into_proto(),
3053                val: *val,
3054            }),
3055            EvalError::InvalidArray(error) => InvalidArray(error.into_proto()),
3056            EvalError::InvalidEncodingName(v) => InvalidEncodingName(v.into_proto()),
3057            EvalError::InvalidHashAlgorithm(v) => InvalidHashAlgorithm(v.into_proto()),
3058            EvalError::InvalidByteSequence {
3059                byte_sequence,
3060                encoding_name,
3061            } => InvalidByteSequence(ProtoInvalidByteSequence {
3062                byte_sequence: byte_sequence.into_proto(),
3063                encoding_name: encoding_name.into_proto(),
3064            }),
3065            EvalError::InvalidJsonbCast { from, to } => InvalidJsonbCast(ProtoInvalidJsonbCast {
3066                from: from.into_proto(),
3067                to: to.into_proto(),
3068            }),
3069            EvalError::InvalidRegex(v) => InvalidRegex(v.into_proto()),
3070            EvalError::InvalidRegexFlag(v) => InvalidRegexFlag(v.into_proto()),
3071            EvalError::InvalidParameterValue(v) => InvalidParameterValue(v.into_proto()),
3072            EvalError::InvalidDatePart(part) => InvalidDatePart(part.into_proto()),
3073            EvalError::KeyCannotBeNull => KeyCannotBeNull(()),
3074            EvalError::NegSqrt => NegSqrt(()),
3075            EvalError::NegLimit => NegLimit(()),
3076            EvalError::NullCharacterNotPermitted => NullCharacterNotPermitted(()),
3077            EvalError::UnknownUnits(v) => UnknownUnits(v.into_proto()),
3078            EvalError::UnsupportedUnits(units, typ) => UnsupportedUnits(ProtoUnsupportedUnits {
3079                units: units.into_proto(),
3080                typ: typ.into_proto(),
3081            }),
3082            EvalError::UnterminatedLikeEscapeSequence => UnterminatedLikeEscapeSequence(()),
3083            EvalError::Parse(error) => Parse(error.into_proto()),
3084            EvalError::PrettyError(error) => PrettyError(error.into_proto()),
3085            EvalError::ParseHex(error) => ParseHex(error.into_proto()),
3086            EvalError::Internal(v) => Internal(v.into_proto()),
3087            EvalError::InfinityOutOfDomain(v) => InfinityOutOfDomain(v.into_proto()),
3088            EvalError::NegativeOutOfDomain(v) => NegativeOutOfDomain(v.into_proto()),
3089            EvalError::ZeroOutOfDomain(v) => ZeroOutOfDomain(v.into_proto()),
3090            EvalError::OutOfDomain(lower, upper, id) => OutOfDomain(ProtoOutOfDomain {
3091                lower: Some(lower.into_proto()),
3092                upper: Some(upper.into_proto()),
3093                id: id.into_proto(),
3094            }),
3095            EvalError::ComplexOutOfRange(v) => ComplexOutOfRange(v.into_proto()),
3096            EvalError::MultipleRowsFromSubquery => MultipleRowsFromSubquery(()),
3097            EvalError::Undefined(v) => Undefined(v.into_proto()),
3098            EvalError::LikePatternTooLong => LikePatternTooLong(()),
3099            EvalError::LikeEscapeTooLong => LikeEscapeTooLong(()),
3100            EvalError::StringValueTooLong {
3101                target_type,
3102                length,
3103            } => StringValueTooLong(ProtoStringValueTooLong {
3104                target_type: target_type.into_proto(),
3105                length: length.into_proto(),
3106            }),
3107            EvalError::MultidimensionalArrayRemovalNotSupported => {
3108                MultidimensionalArrayRemovalNotSupported(())
3109            }
3110            EvalError::IncompatibleArrayDimensions { dims } => {
3111                IncompatibleArrayDimensions(ProtoIncompatibleArrayDimensions {
3112                    dims: dims.into_proto(),
3113                })
3114            }
3115            EvalError::TypeFromOid(v) => TypeFromOid(v.into_proto()),
3116            EvalError::InvalidRange(error) => InvalidRange(error.into_proto()),
3117            EvalError::InvalidRoleId(v) => InvalidRoleId(v.into_proto()),
3118            EvalError::InvalidPrivileges(v) => InvalidPrivileges(v.into_proto()),
3119            EvalError::LetRecLimitExceeded(v) => WmrRecursionLimitExceeded(v.into_proto()),
3120            EvalError::MultiDimensionalArraySearch => MultiDimensionalArraySearch(()),
3121            EvalError::MustNotBeNull(v) => MustNotBeNull(v.into_proto()),
3122            EvalError::InvalidIdentifier { ident, detail } => {
3123                InvalidIdentifier(ProtoInvalidIdentifier {
3124                    ident: ident.into_proto(),
3125                    detail: detail.into_proto(),
3126                })
3127            }
3128            EvalError::ArrayFillWrongArraySubscripts => ArrayFillWrongArraySubscripts(()),
3129            EvalError::MaxArraySizeExceeded(max_size) => {
3130                MaxArraySizeExceeded(u64::cast_from(*max_size))
3131            }
3132            EvalError::DateDiffOverflow { unit, a, b } => DateDiffOverflow(ProtoDateDiffOverflow {
3133                unit: unit.into_proto(),
3134                a: a.into_proto(),
3135                b: b.into_proto(),
3136            }),
3137            EvalError::IfNullError(s) => IfNullError(s.into_proto()),
3138            EvalError::LengthTooLarge => LengthTooLarge(()),
3139            EvalError::AclArrayNullElement => AclArrayNullElement(()),
3140            EvalError::MzAclArrayNullElement => MzAclArrayNullElement(()),
3141            EvalError::InvalidIanaTimezoneId(s) => InvalidIanaTimezoneId(s.into_proto()),
3142        };
3143        ProtoEvalError { kind: Some(kind) }
3144    }
3145
3146    fn from_proto(proto: ProtoEvalError) -> Result<Self, TryFromProtoError> {
3147        use proto_eval_error::Kind::*;
3148        match proto.kind {
3149            Some(kind) => match kind {
3150                CharacterNotValidForEncoding(v) => Ok(EvalError::CharacterNotValidForEncoding(v)),
3151                CharacterTooLargeForEncoding(v) => Ok(EvalError::CharacterTooLargeForEncoding(v)),
3152                DateBinOutOfRange(v) => Ok(EvalError::DateBinOutOfRange(v.into())),
3153                DivisionByZero(()) => Ok(EvalError::DivisionByZero),
3154                Unsupported(v) => Ok(EvalError::Unsupported {
3155                    feature: v.feature.into(),
3156                    discussion_no: v.discussion_no.into_rust()?,
3157                }),
3158                FloatOverflow(()) => Ok(EvalError::FloatOverflow),
3159                FloatUnderflow(()) => Ok(EvalError::FloatUnderflow),
3160                NumericFieldOverflow(()) => Ok(EvalError::NumericFieldOverflow),
3161                Float32OutOfRange(val) => Ok(EvalError::Float32OutOfRange(val.value.into())),
3162                Float64OutOfRange(val) => Ok(EvalError::Float64OutOfRange(val.value.into())),
3163                Int16OutOfRange(val) => Ok(EvalError::Int16OutOfRange(val.value.into())),
3164                Int32OutOfRange(val) => Ok(EvalError::Int32OutOfRange(val.value.into())),
3165                Int64OutOfRange(val) => Ok(EvalError::Int64OutOfRange(val.value.into())),
3166                Uint16OutOfRange(val) => Ok(EvalError::UInt16OutOfRange(val.value.into())),
3167                Uint32OutOfRange(val) => Ok(EvalError::UInt32OutOfRange(val.value.into())),
3168                Uint64OutOfRange(val) => Ok(EvalError::UInt64OutOfRange(val.value.into())),
3169                MzTimestampOutOfRange(val) => {
3170                    Ok(EvalError::MzTimestampOutOfRange(val.value.into()))
3171                }
3172                MzTimestampStepOverflow(()) => Ok(EvalError::MzTimestampStepOverflow),
3173                OidOutOfRange(val) => Ok(EvalError::OidOutOfRange(val.value.into())),
3174                IntervalOutOfRange(val) => Ok(EvalError::IntervalOutOfRange(val.value.into())),
3175                TimestampCannotBeNan(()) => Ok(EvalError::TimestampCannotBeNan),
3176                TimestampOutOfRange(()) => Ok(EvalError::TimestampOutOfRange),
3177                DateOutOfRange(()) => Ok(EvalError::DateOutOfRange),
3178                CharOutOfRange(()) => Ok(EvalError::CharOutOfRange),
3179                IndexOutOfRange(v) => Ok(EvalError::IndexOutOfRange {
3180                    provided: v.provided,
3181                    valid_end: v.valid_end,
3182                }),
3183                InvalidBase64Equals(()) => Ok(EvalError::InvalidBase64Equals),
3184                InvalidBase64Symbol(v) => char::from_proto(v).map(EvalError::InvalidBase64Symbol),
3185                InvalidBase64EndSequence(()) => Ok(EvalError::InvalidBase64EndSequence),
3186                InvalidTimezone(v) => Ok(EvalError::InvalidTimezone(v.into())),
3187                InvalidTimezoneInterval(()) => Ok(EvalError::InvalidTimezoneInterval),
3188                InvalidTimezoneConversion(()) => Ok(EvalError::InvalidTimezoneConversion),
3189                InvalidLayer(v) => Ok(EvalError::InvalidLayer {
3190                    max_layer: usize::from_proto(v.max_layer)?,
3191                    val: v.val,
3192                }),
3193                InvalidArray(error) => Ok(EvalError::InvalidArray(error.into_rust()?)),
3194                InvalidEncodingName(v) => Ok(EvalError::InvalidEncodingName(v.into())),
3195                InvalidHashAlgorithm(v) => Ok(EvalError::InvalidHashAlgorithm(v.into())),
3196                InvalidByteSequence(v) => Ok(EvalError::InvalidByteSequence {
3197                    byte_sequence: v.byte_sequence.into(),
3198                    encoding_name: v.encoding_name.into(),
3199                }),
3200                InvalidJsonbCast(v) => Ok(EvalError::InvalidJsonbCast {
3201                    from: v.from.into(),
3202                    to: v.to.into(),
3203                }),
3204                InvalidRegex(v) => Ok(EvalError::InvalidRegex(v.into())),
3205                InvalidRegexFlag(v) => Ok(EvalError::InvalidRegexFlag(char::from_proto(v)?)),
3206                InvalidParameterValue(v) => Ok(EvalError::InvalidParameterValue(v.into())),
3207                InvalidDatePart(part) => Ok(EvalError::InvalidDatePart(part.into())),
3208                KeyCannotBeNull(()) => Ok(EvalError::KeyCannotBeNull),
3209                NegSqrt(()) => Ok(EvalError::NegSqrt),
3210                NegLimit(()) => Ok(EvalError::NegLimit),
3211                NullCharacterNotPermitted(()) => Ok(EvalError::NullCharacterNotPermitted),
3212                UnknownUnits(v) => Ok(EvalError::UnknownUnits(v.into())),
3213                UnsupportedUnits(v) => {
3214                    Ok(EvalError::UnsupportedUnits(v.units.into(), v.typ.into()))
3215                }
3216                UnterminatedLikeEscapeSequence(()) => Ok(EvalError::UnterminatedLikeEscapeSequence),
3217                Parse(error) => Ok(EvalError::Parse(error.into_rust()?)),
3218                ParseHex(error) => Ok(EvalError::ParseHex(error.into_rust()?)),
3219                Internal(v) => Ok(EvalError::Internal(v.into())),
3220                InfinityOutOfDomain(v) => Ok(EvalError::InfinityOutOfDomain(v.into())),
3221                NegativeOutOfDomain(v) => Ok(EvalError::NegativeOutOfDomain(v.into())),
3222                ZeroOutOfDomain(v) => Ok(EvalError::ZeroOutOfDomain(v.into())),
3223                OutOfDomain(v) => Ok(EvalError::OutOfDomain(
3224                    v.lower.into_rust_if_some("ProtoDomainLimit::lower")?,
3225                    v.upper.into_rust_if_some("ProtoDomainLimit::upper")?,
3226                    v.id.into(),
3227                )),
3228                ComplexOutOfRange(v) => Ok(EvalError::ComplexOutOfRange(v.into())),
3229                MultipleRowsFromSubquery(()) => Ok(EvalError::MultipleRowsFromSubquery),
3230                Undefined(v) => Ok(EvalError::Undefined(v.into())),
3231                LikePatternTooLong(()) => Ok(EvalError::LikePatternTooLong),
3232                LikeEscapeTooLong(()) => Ok(EvalError::LikeEscapeTooLong),
3233                StringValueTooLong(v) => Ok(EvalError::StringValueTooLong {
3234                    target_type: v.target_type.into(),
3235                    length: usize::from_proto(v.length)?,
3236                }),
3237                MultidimensionalArrayRemovalNotSupported(()) => {
3238                    Ok(EvalError::MultidimensionalArrayRemovalNotSupported)
3239                }
3240                IncompatibleArrayDimensions(v) => Ok(EvalError::IncompatibleArrayDimensions {
3241                    dims: v.dims.into_rust()?,
3242                }),
3243                TypeFromOid(v) => Ok(EvalError::TypeFromOid(v.into())),
3244                InvalidRange(e) => Ok(EvalError::InvalidRange(e.into_rust()?)),
3245                InvalidRoleId(v) => Ok(EvalError::InvalidRoleId(v.into())),
3246                InvalidPrivileges(v) => Ok(EvalError::InvalidPrivileges(v.into())),
3247                WmrRecursionLimitExceeded(v) => Ok(EvalError::LetRecLimitExceeded(v.into())),
3248                MultiDimensionalArraySearch(()) => Ok(EvalError::MultiDimensionalArraySearch),
3249                MustNotBeNull(v) => Ok(EvalError::MustNotBeNull(v.into())),
3250                InvalidIdentifier(v) => Ok(EvalError::InvalidIdentifier {
3251                    ident: v.ident.into(),
3252                    detail: v.detail.into_rust()?,
3253                }),
3254                ArrayFillWrongArraySubscripts(()) => Ok(EvalError::ArrayFillWrongArraySubscripts),
3255                MaxArraySizeExceeded(max_size) => {
3256                    Ok(EvalError::MaxArraySizeExceeded(usize::cast_from(max_size)))
3257                }
3258                DateDiffOverflow(v) => Ok(EvalError::DateDiffOverflow {
3259                    unit: v.unit.into(),
3260                    a: v.a.into(),
3261                    b: v.b.into(),
3262                }),
3263                IfNullError(v) => Ok(EvalError::IfNullError(v.into())),
3264                LengthTooLarge(()) => Ok(EvalError::LengthTooLarge),
3265                AclArrayNullElement(()) => Ok(EvalError::AclArrayNullElement),
3266                MzAclArrayNullElement(()) => Ok(EvalError::MzAclArrayNullElement),
3267                InvalidIanaTimezoneId(s) => Ok(EvalError::InvalidIanaTimezoneId(s.into())),
3268                PrettyError(s) => Ok(EvalError::PrettyError(s.into())),
3269            },
3270            None => Err(TryFromProtoError::missing_field("ProtoEvalError::kind")),
3271        }
3272    }
3273}
3274
3275impl RustType<ProtoDims> for (usize, usize) {
3276    fn into_proto(&self) -> ProtoDims {
3277        ProtoDims {
3278            f0: self.0.into_proto(),
3279            f1: self.1.into_proto(),
3280        }
3281    }
3282
3283    fn from_proto(proto: ProtoDims) -> Result<Self, TryFromProtoError> {
3284        Ok((proto.f0.into_rust()?, proto.f1.into_rust()?))
3285    }
3286}
3287
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `MirScalarExpr::reduce` on a series of `Coalesce` calls and
    /// checks each reduced expression against the expected one.
    ///
    /// The relation used for reduction has three `Int64` columns: columns 0
    /// and 1 are nullable, column 2 is not.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_reduce() {
        let relation_type = vec![
            SqlScalarType::Int64.nullable(true),
            SqlScalarType::Int64.nullable(true),
            SqlScalarType::Int64.nullable(false),
        ];

        // Shorthand constructors for the expressions used below.
        let col = MirScalarExpr::column;
        let err = |e| MirScalarExpr::literal(Err(e), SqlScalarType::Int64);
        let lit = |i| MirScalarExpr::literal_ok(Datum::Int64(i), SqlScalarType::Int64);
        let null = || MirScalarExpr::literal_null(SqlScalarType::Int64);
        let coalesce = |exprs: Vec<MirScalarExpr>| MirScalarExpr::CallVariadic {
            func: VariadicFunc::Coalesce,
            exprs,
        };

        // Each entry is `(input, expected result of reducing the input)`.
        let test_cases = vec![
            // A single literal argument reduces to that literal.
            (coalesce(vec![lit(1)]), lit(1)),
            // The first of several literals is the result.
            (coalesce(vec![lit(1), lit(2)]), lit(1)),
            // Null literals around a literal are dropped.
            (coalesce(vec![null(), lit(2), null()]), lit(2)),
            // Nulls are removed and everything after the first literal is cut.
            (
                coalesce(vec![null(), col(0), null(), col(1), lit(2), lit(3)]),
                coalesce(vec![col(0), col(1), lit(2)]),
            ),
            // Arguments following the non-nullable column 2 are dropped.
            (
                coalesce(vec![col(0), col(2), col(1)]),
                coalesce(vec![col(0), col(2)]),
            ),
            // An error literal after a non-null literal is dropped.
            (
                coalesce(vec![lit(1), err(EvalError::DivisionByZero)]),
                lit(1),
            ),
            // With only nulls and errors, the first error is the result.
            (
                coalesce(vec![
                    null(),
                    err(EvalError::DivisionByZero),
                    err(EvalError::NumericFieldOverflow),
                ]),
                err(EvalError::DivisionByZero),
            ),
        ];

        for (input, expected) in test_cases {
            let mut actual = input.clone();
            actual.reduce(&relation_type);
            assert!(
                actual == expected,
                "input: {}\nactual: {}\nexpected: {}",
                input,
                actual,
                expected
            );
        }
    }
}