mz_expr/
scalar.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::ops::BitOrAssign;
12use std::sync::Arc;
13use std::{fmt, mem};
14
15use itertools::Itertools;
16use mz_lowertest::MzReflect;
17use mz_ore::cast::CastFrom;
18use mz_ore::collections::CollectionExt;
19use mz_ore::iter::IteratorExt;
20use mz_ore::stack::RecursionLimitError;
21use mz_ore::str::StrExt;
22use mz_ore::treat_as_equal::TreatAsEqual;
23use mz_ore::vec::swap_remove_multiple;
24use mz_pgrepr::TypeFromOidError;
25use mz_pgtz::timezone::TimezoneSpec;
26use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
27use mz_repr::adt::array::InvalidArrayError;
28use mz_repr::adt::date::DateError;
29use mz_repr::adt::datetime::DateTimeUnits;
30use mz_repr::adt::range::InvalidRangeError;
31use mz_repr::adt::regex::Regex;
32use mz_repr::adt::timestamp::TimestampError;
33use mz_repr::strconv::{ParseError, ParseHexError};
34use mz_repr::{ColumnType, Datum, Row, RowArena, ScalarType, arb_datum};
35use proptest::prelude::*;
36use proptest_derive::Arbitrary;
37use serde::{Deserialize, Serialize};
38
39use crate::scalar::func::format::DateTimeFormat;
40use crate::scalar::func::{
41    BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, parse_timezone,
42    regexp_replace_parse_flags,
43};
44use crate::scalar::proto_eval_error::proto_incompatible_array_dimensions::ProtoDims;
45use crate::scalar::proto_mir_scalar_expr::*;
46use crate::visit::{Visit, VisitChildren};
47
48pub mod func;
49pub mod like_pattern;
50
51include!(concat!(env!("OUT_DIR"), "/mz_expr.scalar.rs"));
52
/// A scalar expression: a column reference, a literal, a function call of
/// some arity, or a conditional.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, MzReflect)]
pub enum MirScalarExpr {
    /// A column of the input row.
    ///
    /// The first field is the column index; the second is optional name
    /// information for the column.
    // NOTE(review): `TreatAsEqual` presumably keeps the name out of
    // comparisons and hashing -- confirm against `mz_ore::treat_as_equal`.
    Column(usize, TreatAsEqual<Option<Arc<str>>>),
    /// A literal value, or a literal error, together with its column type.
    /// (The value is stored as a row, because we can't own a `Datum`.)
    Literal(Result<Row, EvalError>, ColumnType),
    /// A call to an unmaterializable function.
    ///
    /// These functions cannot be evaluated by `MirScalarExpr::eval`. They must
    /// be transformed away by a higher layer.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function call that takes one expression as an argument.
    CallUnary {
        func: UnaryFunc,
        expr: Box<MirScalarExpr>,
    },
    /// A function call that takes two expressions as arguments.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<MirScalarExpr>,
        expr2: Box<MirScalarExpr>,
    },
    /// A function call that takes an arbitrary number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<MirScalarExpr>,
    },
    /// Conditionally evaluated expressions.
    ///
    /// It is important that `then` and `els` only be evaluated if
    /// `cond` is true or not, respectively. This is the only way
    /// users can guard execution (other logical operators do not
    /// short-circuit) and we need to preserve that.
    If {
        cond: Box<MirScalarExpr>,
        then: Box<MirScalarExpr>,
        els: Box<MirScalarExpr>,
    },
}
93
94// We need a custom Debug because we don't want to show `None` for name information.
95// Sadly, the `derivative` crate doesn't support this use case.
96impl std::fmt::Debug for MirScalarExpr {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        match self {
99            MirScalarExpr::Column(i, TreatAsEqual(Some(name))) => {
100                write!(f, "Column({i}, {name:?})")
101            }
102            MirScalarExpr::Column(i, TreatAsEqual(None)) => write!(f, "Column({i})"),
103            MirScalarExpr::Literal(lit, typ) => write!(f, "Literal({lit:?}, {typ:?})"),
104            MirScalarExpr::CallUnmaterializable(func) => {
105                write!(f, "CallUnmaterializable({func:?})")
106            }
107            MirScalarExpr::CallUnary { func, expr } => {
108                write!(f, "CallUnary({func:?}, {expr:?})")
109            }
110            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
111                write!(f, "CallBinary({func:?}, {expr1:?}, {expr2:?})")
112            }
113            MirScalarExpr::CallVariadic { func, exprs } => {
114                write!(f, "CallVariadic({func:?}, {exprs:?})")
115            }
116            MirScalarExpr::If { cond, then, els } => {
117                write!(f, "If({cond:?}, {then:?}, {els:?})")
118            }
119        }
120    }
121}
122
123impl Arbitrary for MirScalarExpr {
124    type Parameters = ();
125    type Strategy = BoxedStrategy<MirScalarExpr>;
126
127    fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
128        let leaf = prop::strategy::Union::new(vec![
129            (0..10_usize).prop_map(MirScalarExpr::column).boxed(),
130            (arb_datum(), any::<ScalarType>())
131                .prop_map(|(datum, typ)| MirScalarExpr::literal(Ok((&datum).into()), typ))
132                .boxed(),
133            (any::<EvalError>(), any::<ScalarType>())
134                .prop_map(|(err, typ)| MirScalarExpr::literal(Err(err), typ))
135                .boxed(),
136            any::<UnmaterializableFunc>()
137                .prop_map(MirScalarExpr::CallUnmaterializable)
138                .boxed(),
139        ]);
140        leaf.prop_recursive(3, 6, 7, |inner| {
141            prop::strategy::Union::new(vec![
142                (
143                    any::<VariadicFunc>(),
144                    prop::collection::vec(inner.clone(), 1..5),
145                )
146                    .prop_map(|(func, exprs)| MirScalarExpr::CallVariadic { func, exprs })
147                    .boxed(),
148                (any::<BinaryFunc>(), inner.clone(), inner.clone())
149                    .prop_map(|(func, expr1, expr2)| MirScalarExpr::CallBinary {
150                        func,
151                        expr1: Box::new(expr1),
152                        expr2: Box::new(expr2),
153                    })
154                    .boxed(),
155                (inner.clone(), inner.clone(), inner.clone())
156                    .prop_map(|(cond, then, els)| MirScalarExpr::If {
157                        cond: Box::new(cond),
158                        then: Box::new(then),
159                        els: Box::new(els),
160                    })
161                    .boxed(),
162                (any::<UnaryFunc>(), inner)
163                    .prop_map(|(func, expr)| MirScalarExpr::CallUnary {
164                        func,
165                        expr: Box::new(expr),
166                    })
167                    .boxed(),
168            ])
169        })
170        .boxed()
171    }
172}
173
174impl RustType<ProtoMirScalarExpr> for MirScalarExpr {
175    fn into_proto(&self) -> ProtoMirScalarExpr {
176        use proto_mir_scalar_expr::Kind::*;
177        ProtoMirScalarExpr {
178            kind: Some(match self {
179                MirScalarExpr::Column(index, name) => Column(ProtoColumn {
180                    index: index.into_proto(),
181                    name: name.0.as_ref().map(ToString::to_string),
182                }),
183                MirScalarExpr::Literal(lit, typ) => Literal(ProtoLiteral {
184                    lit: Some(lit.into_proto()),
185                    typ: Some(typ.into_proto()),
186                }),
187                MirScalarExpr::CallUnmaterializable(func) => {
188                    CallUnmaterializable(func.into_proto())
189                }
190                MirScalarExpr::CallUnary { func, expr } => CallUnary(Box::new(ProtoCallUnary {
191                    func: Some(Box::new(func.into_proto())),
192                    expr: Some(expr.into_proto()),
193                })),
194                MirScalarExpr::CallBinary { func, expr1, expr2 } => {
195                    CallBinary(Box::new(ProtoCallBinary {
196                        func: Some(func.into_proto()),
197                        expr1: Some(expr1.into_proto()),
198                        expr2: Some(expr2.into_proto()),
199                    }))
200                }
201                MirScalarExpr::CallVariadic { func, exprs } => CallVariadic(ProtoCallVariadic {
202                    func: Some(func.into_proto()),
203                    exprs: exprs.into_proto(),
204                }),
205                MirScalarExpr::If { cond, then, els } => If(Box::new(ProtoIf {
206                    cond: Some(cond.into_proto()),
207                    then: Some(then.into_proto()),
208                    els: Some(els.into_proto()),
209                })),
210            }),
211        }
212    }
213
214    fn from_proto(proto: ProtoMirScalarExpr) -> Result<Self, TryFromProtoError> {
215        use proto_mir_scalar_expr::Kind::*;
216        let kind = proto
217            .kind
218            .ok_or_else(|| TryFromProtoError::missing_field("ProtoMirScalarExpr::kind"))?;
219        Ok(match kind {
220            Column(ProtoColumn { index, name }) => {
221                let index = usize::from_proto(index)?;
222                match name {
223                    Some(name) => MirScalarExpr::named_column(index, name.into()),
224                    None => MirScalarExpr::column(index),
225                }
226            }
227            Literal(ProtoLiteral { lit, typ }) => MirScalarExpr::Literal(
228                lit.into_rust_if_some("ProtoLiteral::lit")?,
229                typ.into_rust_if_some("ProtoLiteral::typ")?,
230            ),
231            CallUnmaterializable(func) => MirScalarExpr::CallUnmaterializable(func.into_rust()?),
232            CallUnary(call_unary) => MirScalarExpr::CallUnary {
233                func: call_unary.func.into_rust_if_some("ProtoCallUnary::func")?,
234                expr: call_unary.expr.into_rust_if_some("ProtoCallUnary::expr")?,
235            },
236            CallBinary(call_binary) => MirScalarExpr::CallBinary {
237                func: call_binary
238                    .func
239                    .into_rust_if_some("ProtoCallBinary::func")?,
240                expr1: call_binary
241                    .expr1
242                    .into_rust_if_some("ProtoCallBinary::expr1")?,
243                expr2: call_binary
244                    .expr2
245                    .into_rust_if_some("ProtoCallBinary::expr2")?,
246            },
247            CallVariadic(ProtoCallVariadic { func, exprs }) => MirScalarExpr::CallVariadic {
248                func: func.into_rust_if_some("ProtoCallVariadic::func")?,
249                exprs: exprs.into_rust()?,
250            },
251            If(if_struct) => MirScalarExpr::If {
252                cond: if_struct.cond.into_rust_if_some("ProtoIf::cond")?,
253                then: if_struct.then.into_rust_if_some("ProtoIf::then")?,
254                els: if_struct.els.into_rust_if_some("ProtoIf::els")?,
255            },
256        })
257    }
258}
259
260impl RustType<proto_literal::ProtoLiteralData> for Result<Row, EvalError> {
261    fn into_proto(&self) -> proto_literal::ProtoLiteralData {
262        use proto_literal::proto_literal_data::Result::*;
263        proto_literal::ProtoLiteralData {
264            result: Some(match self {
265                Result::Ok(row) => Ok(row.into_proto()),
266                Result::Err(err) => Err(err.into_proto()),
267            }),
268        }
269    }
270
271    fn from_proto(proto: proto_literal::ProtoLiteralData) -> Result<Self, TryFromProtoError> {
272        use proto_literal::proto_literal_data::Result::*;
273        match proto.result {
274            Some(Ok(row)) => Result::Ok(Result::Ok(
275                (&row)
276                    .try_into()
277                    .map_err(TryFromProtoError::RowConversionError)?,
278            )),
279            Some(Err(err)) => Result::Ok(Result::Err(err.into_rust()?)),
280            None => Result::Err(TryFromProtoError::missing_field("ProtoLiteralData::result")),
281        }
282    }
283}
284
285impl MirScalarExpr {
286    pub fn columns(is: &[usize]) -> Vec<MirScalarExpr> {
287        is.iter().map(|i| MirScalarExpr::column(*i)).collect()
288    }
289
290    pub fn column(column: usize) -> Self {
291        MirScalarExpr::Column(column, TreatAsEqual(None))
292    }
293
294    pub fn named_column(column: usize, name: Arc<str>) -> Self {
295        MirScalarExpr::Column(column, TreatAsEqual(Some(name)))
296    }
297
298    pub fn literal(res: Result<Datum, EvalError>, typ: ScalarType) -> Self {
299        let typ = typ.nullable(matches!(res, Ok(Datum::Null)));
300        let row = res.map(|datum| Row::pack_slice(&[datum]));
301        MirScalarExpr::Literal(row, typ)
302    }
303
304    pub fn literal_ok(datum: Datum, typ: ScalarType) -> Self {
305        MirScalarExpr::literal(Ok(datum), typ)
306    }
307
308    pub fn literal_null(typ: ScalarType) -> Self {
309        MirScalarExpr::literal_ok(Datum::Null, typ)
310    }
311
312    pub fn literal_false() -> Self {
313        MirScalarExpr::literal_ok(Datum::False, ScalarType::Bool)
314    }
315
316    pub fn literal_true() -> Self {
317        MirScalarExpr::literal_ok(Datum::True, ScalarType::Bool)
318    }
319
320    pub fn call_unary(self, func: UnaryFunc) -> Self {
321        MirScalarExpr::CallUnary {
322            func,
323            expr: Box::new(self),
324        }
325    }
326
327    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
328        MirScalarExpr::CallBinary {
329            func,
330            expr1: Box::new(self),
331            expr2: Box::new(other),
332        }
333    }
334
335    pub fn if_then_else(self, t: Self, f: Self) -> Self {
336        MirScalarExpr::If {
337            cond: Box::new(self),
338            then: Box::new(t),
339            els: Box::new(f),
340        }
341    }
342
343    pub fn or(self, other: Self) -> Self {
344        MirScalarExpr::CallVariadic {
345            func: VariadicFunc::Or,
346            exprs: vec![self, other],
347        }
348    }
349
350    pub fn and(self, other: Self) -> Self {
351        MirScalarExpr::CallVariadic {
352            func: VariadicFunc::And,
353            exprs: vec![self, other],
354        }
355    }
356
357    pub fn not(self) -> Self {
358        self.call_unary(UnaryFunc::Not(func::Not))
359    }
360
361    pub fn call_is_null(self) -> Self {
362        self.call_unary(UnaryFunc::IsNull(func::IsNull))
363    }
364
365    /// Match AND or OR on self and get the args. If no match, then interpret self as if it were
366    /// wrapped in a 1-arg AND/OR.
367    pub fn and_or_args(&self, func_to_match: VariadicFunc) -> Vec<MirScalarExpr> {
368        assert!(func_to_match == VariadicFunc::Or || func_to_match == VariadicFunc::And);
369        match self {
370            MirScalarExpr::CallVariadic { func, exprs } if *func == func_to_match => exprs.clone(),
371            _ => vec![self.clone()],
372        }
373    }
374
375    /// Try to match a literal equality involving the given expression on one side.
376    /// Return the (non-null) literal and a bool that indicates whether an inversion was needed.
377    ///
378    /// More specifically:
379    /// If `self` is an equality with a `null` literal on any side, then the match fails!
380    /// Otherwise: for a given `expr`, if `self` is `<expr> = <literal>` or `<literal> = <expr>`
381    /// then return `Some((<literal>, false))`. In addition to just trying to match `<expr>` as it
382    /// is, we also try to remove an invertible function call (such as a cast). If the match
383    /// succeeds with the inversion, then return `Some((<inverted-literal>, true))`. For more
384    /// details on the inversion, see `invert_casts_on_expr_eq_literal_inner`.
385    pub fn expr_eq_literal(&self, expr: &MirScalarExpr) -> Option<(Row, bool)> {
386        if let MirScalarExpr::CallBinary {
387            func: BinaryFunc::Eq,
388            expr1,
389            expr2,
390        } = self
391        {
392            if expr1.is_literal_null() || expr2.is_literal_null() {
393                return None;
394            }
395            if let Some(Ok(lit)) = expr1.as_literal_owned() {
396                return Self::expr_eq_literal_inner(expr, lit, expr1, expr2);
397            }
398            if let Some(Ok(lit)) = expr2.as_literal_owned() {
399                return Self::expr_eq_literal_inner(expr, lit, expr2, expr1);
400            }
401        }
402        None
403    }
404
405    fn expr_eq_literal_inner(
406        expr_to_match: &MirScalarExpr,
407        literal: Row,
408        literal_expr: &MirScalarExpr,
409        other_side: &MirScalarExpr,
410    ) -> Option<(Row, bool)> {
411        if other_side == expr_to_match {
412            return Some((literal, false));
413        } else {
414            // expr didn't exactly match. See if we can match it by inverse-casting.
415            let (cast_removed, inv_cast_lit) =
416                Self::invert_casts_on_expr_eq_literal_inner(other_side, literal_expr);
417            if &cast_removed == expr_to_match {
418                if let Some(Ok(inv_cast_lit_row)) = inv_cast_lit.as_literal_owned() {
419                    return Some((inv_cast_lit_row, true));
420                }
421            }
422        }
423        None
424    }
425
426    /// If `self` is `<expr> = <literal>` or `<literal> = <expr>` then
427    /// return `<expr>`. It also tries to remove a cast (or other invertible function call) from
428    /// `<expr>` before returning it, see `invert_casts_on_expr_eq_literal_inner`.
429    pub fn any_expr_eq_literal(&self) -> Option<MirScalarExpr> {
430        if let MirScalarExpr::CallBinary {
431            func: BinaryFunc::Eq,
432            expr1,
433            expr2,
434        } = self
435        {
436            if expr1.is_literal() {
437                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
438                return Some(expr);
439            }
440            if expr2.is_literal() {
441                let (expr, _literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
442                return Some(expr);
443            }
444        }
445        None
446    }
447
448    /// If the given `MirScalarExpr` is a literal equality where one side is an invertible function
449    /// call, then calls the inverse function on both sides of the equality and returns the modified
450    /// version of the given `MirScalarExpr`. Otherwise, it returns the original expression.
451    /// For more details, see `invert_casts_on_expr_eq_literal_inner`.
452    pub fn invert_casts_on_expr_eq_literal(&self) -> MirScalarExpr {
453        if let MirScalarExpr::CallBinary {
454            func: BinaryFunc::Eq,
455            expr1,
456            expr2,
457        } = self
458        {
459            if expr1.is_literal() {
460                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr2, expr1);
461                return MirScalarExpr::CallBinary {
462                    func: BinaryFunc::Eq,
463                    expr1: Box::new(literal),
464                    expr2: Box::new(expr),
465                };
466            }
467            if expr2.is_literal() {
468                let (expr, literal) = Self::invert_casts_on_expr_eq_literal_inner(expr1, expr2);
469                return MirScalarExpr::CallBinary {
470                    func: BinaryFunc::Eq,
471                    expr1: Box::new(literal),
472                    expr2: Box::new(expr),
473                };
474            }
475            // Note: The above return statements should be consistent in whether they put the
476            // literal in expr1 or expr2, for the deduplication in CanonicalizeMfp to work.
477        }
478        self.clone()
479    }
480
481    /// Given an `<expr>` and a `<literal>` that were taken out from `<expr> = <literal>` or
482    /// `<literal> = <expr>`, it tries to simplify the equality by applying the inverse function of
483    /// the outermost function call of `<expr>` (if exists):
484    ///
485    /// `<literal> = func(<inner_expr>)`, where `func` is invertible
486    ///  -->
487    /// `<func^-1(literal)> = <inner_expr>`
488    /// if `func^-1(literal)` doesn't error out, and both `func` and `func^-1` preserve uniqueness.
489    ///
490    /// The return value is the `<inner_expr>` and the literal value that we get by applying the
491    /// inverse function.
492    fn invert_casts_on_expr_eq_literal_inner(
493        expr: &MirScalarExpr,
494        literal: &MirScalarExpr,
495    ) -> (MirScalarExpr, MirScalarExpr) {
496        assert!(matches!(literal, MirScalarExpr::Literal(..)));
497
498        let temp_storage = &RowArena::new();
499        let eval = |e: &MirScalarExpr| {
500            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
501        };
502
503        if let MirScalarExpr::CallUnary {
504            func,
505            expr: inner_expr,
506        } = expr
507        {
508            if let Some(inverse_func) = func.inverse() {
509                // We don't want to remove a function call that doesn't preserve uniqueness, e.g.,
510                // if `f` is a float, we don't want to inverse-cast `f::INT = 0`, because the
511                // inserted int-to-float cast wouldn't be able to invert the rounding.
512                // Also, we don't want to insert a function call that doesn't preserve
513                // uniqueness. E.g., if `a` has an integer type, we don't want to do
514                // a surprise rounding for `WHERE a = 3.14`.
515                if func.preserves_uniqueness() && inverse_func.preserves_uniqueness() {
516                    let lit_inv = eval(&MirScalarExpr::CallUnary {
517                        func: inverse_func,
518                        expr: Box::new(literal.clone()),
519                    });
520                    // The evaluation can error out, e.g., when casting a too large int32 to int16.
521                    // This case is handled by `impossible_literal_equality_because_types`.
522                    if !lit_inv.is_literal_err() {
523                        return (*inner_expr.clone(), lit_inv);
524                    }
525                }
526            }
527        }
528        (expr.clone(), literal.clone())
529    }
530
531    /// Tries to remove a cast (or other invertible function) in the same way as
532    /// `invert_casts_on_expr_eq_literal`, but if calling the inverse function fails on the literal,
533    /// then it deems the equality to be impossible. For example if `a` is a smallint column, then
534    /// it catches `a::integer = 1000000` to be an always false predicate (where the `::integer`
535    /// could have been inserted implicitly).
536    pub fn impossible_literal_equality_because_types(&self) -> bool {
537        if let MirScalarExpr::CallBinary {
538            func: BinaryFunc::Eq,
539            expr1,
540            expr2,
541        } = self
542        {
543            if expr1.is_literal() {
544                return Self::impossible_literal_equality_because_types_inner(expr1, expr2);
545            }
546            if expr2.is_literal() {
547                return Self::impossible_literal_equality_because_types_inner(expr2, expr1);
548            }
549        }
550        false
551    }
552
553    fn impossible_literal_equality_because_types_inner(
554        literal: &MirScalarExpr,
555        other_side: &MirScalarExpr,
556    ) -> bool {
557        assert!(matches!(literal, MirScalarExpr::Literal(..)));
558
559        let temp_storage = &RowArena::new();
560        let eval = |e: &MirScalarExpr| {
561            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(&[]).scalar_type)
562        };
563
564        if let MirScalarExpr::CallUnary { func, .. } = other_side {
565            if let Some(inverse_func) = func.inverse() {
566                if inverse_func.preserves_uniqueness()
567                    && eval(&MirScalarExpr::CallUnary {
568                        func: inverse_func,
569                        expr: Box::new(literal.clone()),
570                    })
571                    .is_literal_err()
572                {
573                    return true;
574                }
575            }
576        }
577
578        false
579    }
580
581    /// Determines if `self` is
582    /// `<expr> < <literal>` or
583    /// `<expr> > <literal>` or
584    /// `<literal> < <expr>` or
585    /// `<literal> > <expr>` or
586    /// `<expr> <= <literal>` or
587    /// `<expr> >= <literal>` or
588    /// `<literal> <= <expr>` or
589    /// `<literal> >= <expr>`.
590    pub fn any_expr_ineq_literal(&self) -> bool {
591        match self {
592            MirScalarExpr::CallBinary {
593                func: BinaryFunc::Lt | BinaryFunc::Lte | BinaryFunc::Gt | BinaryFunc::Gte,
594                expr1,
595                expr2,
596            } => expr1.is_literal() || expr2.is_literal(),
597            _ => false,
598        }
599    }
600
601    /// Rewrites column indices with their value in `permutation`.
602    ///
603    /// This method is applicable even when `permutation` is not a
604    /// strict permutation, and it only needs to have entries for
605    /// each column referenced in `self`.
606    pub fn permute(&mut self, permutation: &[usize]) {
607        self.visit_columns(|c| *c = permutation[*c]);
608    }
609
610    /// Rewrites column indices with their value in `permutation`.
611    ///
612    /// This method is applicable even when `permutation` is not a
613    /// strict permutation, and it only needs to have entries for
614    /// each column referenced in `self`.
615    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
616        self.visit_columns(|c| *c = permutation[c]);
617    }
618
619    /// Visits each column reference and applies `action` to the column.
620    ///
621    /// Useful for remapping columns, or for collecting expression support.
622    pub fn visit_columns<F>(&mut self, mut action: F)
623    where
624        F: FnMut(&mut usize),
625    {
626        self.visit_pre_mut(|e| {
627            if let MirScalarExpr::Column(col, _) = e {
628                action(col);
629            }
630        });
631    }
632
633    pub fn support(&self) -> BTreeSet<usize> {
634        let mut support = BTreeSet::new();
635        self.support_into(&mut support);
636        support
637    }
638
639    pub fn support_into(&self, support: &mut BTreeSet<usize>) {
640        self.visit_pre(|e| {
641            if let MirScalarExpr::Column(i, _) = e {
642                support.insert(*i);
643            }
644        });
645    }
646
647    pub fn take(&mut self) -> Self {
648        mem::replace(self, MirScalarExpr::literal_null(ScalarType::String))
649    }
650
651    pub fn as_literal(&self) -> Option<Result<Datum, &EvalError>> {
652        if let MirScalarExpr::Literal(lit, _column_type) = self {
653            Some(lit.as_ref().map(|row| row.unpack_first()))
654        } else {
655            None
656        }
657    }
658
659    pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
660        if let MirScalarExpr::Literal(lit, _column_type) = self {
661            Some(lit.clone())
662        } else {
663            None
664        }
665    }
666
667    pub fn as_literal_str(&self) -> Option<&str> {
668        match self.as_literal() {
669            Some(Ok(Datum::String(s))) => Some(s),
670            _ => None,
671        }
672    }
673
674    pub fn as_literal_int64(&self) -> Option<i64> {
675        match self.as_literal() {
676            Some(Ok(Datum::Int64(i))) => Some(i),
677            _ => None,
678        }
679    }
680
681    pub fn as_literal_err(&self) -> Option<&EvalError> {
682        self.as_literal().and_then(|lit| lit.err())
683    }
684
685    pub fn is_literal(&self) -> bool {
686        matches!(self, MirScalarExpr::Literal(_, _))
687    }
688
689    pub fn is_literal_true(&self) -> bool {
690        Some(Ok(Datum::True)) == self.as_literal()
691    }
692
693    pub fn is_literal_false(&self) -> bool {
694        Some(Ok(Datum::False)) == self.as_literal()
695    }
696
697    pub fn is_literal_null(&self) -> bool {
698        Some(Ok(Datum::Null)) == self.as_literal()
699    }
700
701    pub fn is_literal_ok(&self) -> bool {
702        matches!(self, MirScalarExpr::Literal(Ok(_), _typ))
703    }
704
705    pub fn is_literal_err(&self) -> bool {
706        matches!(self, MirScalarExpr::Literal(Err(_), _typ))
707    }
708
709    pub fn is_column(&self) -> bool {
710        matches!(self, MirScalarExpr::Column(_col, _name))
711    }
712
713    pub fn is_error_if_null(&self) -> bool {
714        matches!(
715            self,
716            Self::CallVariadic {
717                func: VariadicFunc::ErrorIfNull,
718                ..
719            }
720        )
721    }
722
723    #[deprecated = "Use `might_error` instead"]
724    pub fn contains_error_if_null(&self) -> bool {
725        let mut worklist = vec![self];
726        while let Some(expr) = worklist.pop() {
727            if expr.is_error_if_null() {
728                return true;
729            }
730            worklist.extend(expr.children());
731        }
732        false
733    }
734
735    /// A very crude approximation for scalar expressions that might produce an
736    /// error.
737    ///
738    /// Currently, this is restricted only to expressions that either contain a
739    /// literal error or a [`VariadicFunc::ErrorIfNull`] call.
740    pub fn might_error(&self) -> bool {
741        let mut worklist = vec![self];
742        while let Some(expr) = worklist.pop() {
743            if expr.is_literal_err() || expr.is_error_if_null() {
744                return true;
745            }
746            worklist.extend(expr.children());
747        }
748        false
749    }
750
751    /// If self is a column, return the column index, otherwise `None`.
752    pub fn as_column(&self) -> Option<usize> {
753        if let MirScalarExpr::Column(c, _) = self {
754            Some(*c)
755        } else {
756            None
757        }
758    }
759
760    /// Reduces a complex expression where possible.
761    ///
762    /// This function uses nullability information present in `column_types`,
763    /// and the result may only continue to be a correct transformation as
764    /// long as this information continues to hold (nullability may not hold
765    /// as expressions migrate around).
766    ///
767    /// (If you'd like to not use nullability information here, then you can
768    /// tweak the nullabilities in `column_types` before passing it to this
769    /// function, see e.g. in `EquivalenceClasses::minimize`.)
770    ///
771    /// Also performs partial canonicalization on the expression.
772    ///
773    /// ```rust
774    /// use mz_expr::MirScalarExpr;
775    /// use mz_repr::{ColumnType, Datum, ScalarType};
776    ///
777    /// let expr_0 = MirScalarExpr::column(0);
778    /// let expr_t = MirScalarExpr::literal_true();
779    /// let expr_f = MirScalarExpr::literal_false();
780    ///
781    /// let mut test =
782    /// expr_t
783    ///     .clone()
784    ///     .and(expr_f.clone())
785    ///     .if_then_else(expr_0, expr_t.clone());
786    ///
787    /// let input_type = vec![ScalarType::Int32.nullable(false)];
788    /// test.reduce(&input_type);
789    /// assert_eq!(test, expr_t);
790    /// ```
791    pub fn reduce(&mut self, column_types: &[ColumnType]) {
792        let temp_storage = &RowArena::new();
793        let eval = |e: &MirScalarExpr| {
794            MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type)
795        };
796
797        // Simplifications run in a loop until `self` no longer changes.
798        let mut old_self = MirScalarExpr::column(0);
799        while old_self != *self {
800            old_self = self.clone();
801            #[allow(deprecated)]
802            self.visit_mut_pre_post_nolimit(
803                &mut |e| {
804                    match e {
805                        MirScalarExpr::CallUnary { func, expr } => {
806                            if *func == UnaryFunc::IsNull(func::IsNull) {
807                                if !expr.typ(column_types).nullable {
808                                    *e = MirScalarExpr::literal_false();
809                                } else {
810                                    // Try to at least decompose IsNull into a disjunction
811                                    // of simpler IsNull subexpressions.
812                                    if let Some(expr) = expr.decompose_is_null() {
813                                        *e = expr
814                                    }
815                                }
816                            } else if *func == UnaryFunc::Not(func::Not) {
817                                // Push down not expressions
818                                match &mut **expr {
819                                    // Two negates cancel each other out.
820                                    MirScalarExpr::CallUnary {
821                                        expr: inner_expr,
822                                        func: UnaryFunc::Not(func::Not),
823                                    } => *e = inner_expr.take(),
824                                    // Transforms `NOT(a <op> b)` to `a negate(<op>) b`
825                                    // if a negation exists.
826                                    MirScalarExpr::CallBinary { expr1, expr2, func } => {
827                                        if let Some(negated_func) = func.negate() {
828                                            *e = MirScalarExpr::CallBinary {
829                                                expr1: Box::new(expr1.take()),
830                                                expr2: Box::new(expr2.take()),
831                                                func: negated_func,
832                                            }
833                                        }
834                                    }
835                                    MirScalarExpr::CallVariadic { .. } => {
836                                        e.demorgans();
837                                    }
838                                    _ => {}
839                                }
840                            }
841                        }
842                        _ => {}
843                    };
844                    None
845                },
846                &mut |e| match e {
847                    // Evaluate and pull up constants
848                    MirScalarExpr::Column(_, _)
849                    | MirScalarExpr::Literal(_, _)
850                    | MirScalarExpr::CallUnmaterializable(_) => (),
851                    MirScalarExpr::CallUnary { func, expr } => {
852                        if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
853                            *e = eval(e);
854                        } else if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
855                            if let MirScalarExpr::CallVariadic {
856                                func: VariadicFunc::RecordCreate { .. },
857                                exprs,
858                            } = &mut **expr
859                            {
860                                *e = exprs.swap_remove(i);
861                            }
862                        }
863                    }
864                    MirScalarExpr::CallBinary { func, expr1, expr2 } => {
865                        if expr1.is_literal() && expr2.is_literal() {
866                            *e = eval(e);
867                        } else if (expr1.is_literal_null() || expr2.is_literal_null())
868                            && func.propagates_nulls()
869                        {
870                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
871                        } else if let Some(err) = expr1.as_literal_err() {
872                            *e = MirScalarExpr::literal(
873                                Err(err.clone()),
874                                e.typ(column_types).scalar_type,
875                            );
876                        } else if let Some(err) = expr2.as_literal_err() {
877                            *e = MirScalarExpr::literal(
878                                Err(err.clone()),
879                                e.typ(column_types).scalar_type,
880                            );
881                        } else if let BinaryFunc::IsLikeMatch { case_insensitive } = func {
882                            if expr2.is_literal() {
883                                // We can at least precompile the regex.
884                                let pattern = expr2.as_literal_str().unwrap();
885                                *e = match like_pattern::compile(pattern, *case_insensitive) {
886                                    Ok(matcher) => expr1.take().call_unary(UnaryFunc::IsLikeMatch(
887                                        func::IsLikeMatch(matcher),
888                                    )),
889                                    Err(err) => MirScalarExpr::literal(
890                                        Err(err),
891                                        e.typ(column_types).scalar_type,
892                                    ),
893                                };
894                            }
895                        } else if let BinaryFunc::IsRegexpMatch { case_insensitive } = func {
896                            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
897                                *e = match Regex::new(
898                                    row.unpack_first().unwrap_str(),
899                                    *case_insensitive,
900                                ) {
901                                    Ok(regex) => expr1.take().call_unary(UnaryFunc::IsRegexpMatch(
902                                        func::IsRegexpMatch(regex),
903                                    )),
904                                    Err(err) => MirScalarExpr::literal(
905                                        Err(err.into()),
906                                        e.typ(column_types).scalar_type,
907                                    ),
908                                };
909                            }
910                        } else if *func == BinaryFunc::ExtractInterval && expr1.is_literal() {
911                            let units = expr1.as_literal_str().unwrap();
912                            *e = match units.parse::<DateTimeUnits>() {
913                                Ok(units) => MirScalarExpr::CallUnary {
914                                    func: UnaryFunc::ExtractInterval(func::ExtractInterval(units)),
915                                    expr: Box::new(expr2.take()),
916                                },
917                                Err(_) => MirScalarExpr::literal(
918                                    Err(EvalError::UnknownUnits(units.into())),
919                                    e.typ(column_types).scalar_type,
920                                ),
921                            }
922                        } else if *func == BinaryFunc::ExtractTime && expr1.is_literal() {
923                            let units = expr1.as_literal_str().unwrap();
924                            *e = match units.parse::<DateTimeUnits>() {
925                                Ok(units) => MirScalarExpr::CallUnary {
926                                    func: UnaryFunc::ExtractTime(func::ExtractTime(units)),
927                                    expr: Box::new(expr2.take()),
928                                },
929                                Err(_) => MirScalarExpr::literal(
930                                    Err(EvalError::UnknownUnits(units.into())),
931                                    e.typ(column_types).scalar_type,
932                                ),
933                            }
934                        } else if *func == BinaryFunc::ExtractTimestamp && expr1.is_literal() {
935                            let units = expr1.as_literal_str().unwrap();
936                            *e = match units.parse::<DateTimeUnits>() {
937                                Ok(units) => MirScalarExpr::CallUnary {
938                                    func: UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(
939                                        units,
940                                    )),
941                                    expr: Box::new(expr2.take()),
942                                },
943                                Err(_) => MirScalarExpr::literal(
944                                    Err(EvalError::UnknownUnits(units.into())),
945                                    e.typ(column_types).scalar_type,
946                                ),
947                            }
948                        } else if *func == BinaryFunc::ExtractTimestampTz && expr1.is_literal() {
949                            let units = expr1.as_literal_str().unwrap();
950                            *e = match units.parse::<DateTimeUnits>() {
951                                Ok(units) => MirScalarExpr::CallUnary {
952                                    func: UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(
953                                        units,
954                                    )),
955                                    expr: Box::new(expr2.take()),
956                                },
957                                Err(_) => MirScalarExpr::literal(
958                                    Err(EvalError::UnknownUnits(units.into())),
959                                    e.typ(column_types).scalar_type,
960                                ),
961                            }
962                        } else if *func == BinaryFunc::ExtractDate && expr1.is_literal() {
963                            let units = expr1.as_literal_str().unwrap();
964                            *e = match units.parse::<DateTimeUnits>() {
965                                Ok(units) => MirScalarExpr::CallUnary {
966                                    func: UnaryFunc::ExtractDate(func::ExtractDate(units)),
967                                    expr: Box::new(expr2.take()),
968                                },
969                                Err(_) => MirScalarExpr::literal(
970                                    Err(EvalError::UnknownUnits(units.into())),
971                                    e.typ(column_types).scalar_type,
972                                ),
973                            }
974                        } else if *func == BinaryFunc::DatePartInterval && expr1.is_literal() {
975                            let units = expr1.as_literal_str().unwrap();
976                            *e = match units.parse::<DateTimeUnits>() {
977                                Ok(units) => MirScalarExpr::CallUnary {
978                                    func: UnaryFunc::DatePartInterval(func::DatePartInterval(
979                                        units,
980                                    )),
981                                    expr: Box::new(expr2.take()),
982                                },
983                                Err(_) => MirScalarExpr::literal(
984                                    Err(EvalError::UnknownUnits(units.into())),
985                                    e.typ(column_types).scalar_type,
986                                ),
987                            }
988                        } else if *func == BinaryFunc::DatePartTime && expr1.is_literal() {
989                            let units = expr1.as_literal_str().unwrap();
990                            *e = match units.parse::<DateTimeUnits>() {
991                                Ok(units) => MirScalarExpr::CallUnary {
992                                    func: UnaryFunc::DatePartTime(func::DatePartTime(units)),
993                                    expr: Box::new(expr2.take()),
994                                },
995                                Err(_) => MirScalarExpr::literal(
996                                    Err(EvalError::UnknownUnits(units.into())),
997                                    e.typ(column_types).scalar_type,
998                                ),
999                            }
1000                        } else if *func == BinaryFunc::DatePartTimestamp && expr1.is_literal() {
1001                            let units = expr1.as_literal_str().unwrap();
1002                            *e = match units.parse::<DateTimeUnits>() {
1003                                Ok(units) => MirScalarExpr::CallUnary {
1004                                    func: UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(
1005                                        units,
1006                                    )),
1007                                    expr: Box::new(expr2.take()),
1008                                },
1009                                Err(_) => MirScalarExpr::literal(
1010                                    Err(EvalError::UnknownUnits(units.into())),
1011                                    e.typ(column_types).scalar_type,
1012                                ),
1013                            }
1014                        } else if *func == BinaryFunc::DatePartTimestampTz && expr1.is_literal() {
1015                            let units = expr1.as_literal_str().unwrap();
1016                            *e = match units.parse::<DateTimeUnits>() {
1017                                Ok(units) => MirScalarExpr::CallUnary {
1018                                    func: UnaryFunc::DatePartTimestampTz(
1019                                        func::DatePartTimestampTz(units),
1020                                    ),
1021                                    expr: Box::new(expr2.take()),
1022                                },
1023                                Err(_) => MirScalarExpr::literal(
1024                                    Err(EvalError::UnknownUnits(units.into())),
1025                                    e.typ(column_types).scalar_type,
1026                                ),
1027                            }
1028                        } else if *func == BinaryFunc::DateTruncTimestamp && expr1.is_literal() {
1029                            let units = expr1.as_literal_str().unwrap();
1030                            *e = match units.parse::<DateTimeUnits>() {
1031                                Ok(units) => MirScalarExpr::CallUnary {
1032                                    func: UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(
1033                                        units,
1034                                    )),
1035                                    expr: Box::new(expr2.take()),
1036                                },
1037                                Err(_) => MirScalarExpr::literal(
1038                                    Err(EvalError::UnknownUnits(units.into())),
1039                                    e.typ(column_types).scalar_type,
1040                                ),
1041                            }
1042                        } else if *func == BinaryFunc::DateTruncTimestampTz && expr1.is_literal() {
1043                            let units = expr1.as_literal_str().unwrap();
1044                            *e = match units.parse::<DateTimeUnits>() {
1045                                Ok(units) => MirScalarExpr::CallUnary {
1046                                    func: UnaryFunc::DateTruncTimestampTz(
1047                                        func::DateTruncTimestampTz(units),
1048                                    ),
1049                                    expr: Box::new(expr2.take()),
1050                                },
1051                                Err(_) => MirScalarExpr::literal(
1052                                    Err(EvalError::UnknownUnits(units.into())),
1053                                    e.typ(column_types).scalar_type,
1054                                ),
1055                            }
1056                        } else if *func == BinaryFunc::TimezoneTimestamp && expr1.is_literal() {
1057                            // If the timezone argument is a literal, and we're applying the function on many rows at the same
1058                            // time we really don't want to parse it again and again, so we parse it once and embed it into the
1059                            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
1060                            let tz = expr1.as_literal_str().unwrap();
1061                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1062                                Ok(tz) => MirScalarExpr::CallUnary {
1063                                    func: UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz)),
1064                                    expr: Box::new(expr2.take()),
1065                                },
1066                                Err(err) => MirScalarExpr::literal(
1067                                    Err(err),
1068                                    e.typ(column_types).scalar_type,
1069                                ),
1070                            }
1071                        } else if *func == BinaryFunc::TimezoneTimestampTz && expr1.is_literal() {
1072                            let tz = expr1.as_literal_str().unwrap();
1073                            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1074                                Ok(tz) => MirScalarExpr::CallUnary {
1075                                    func: UnaryFunc::TimezoneTimestampTz(
1076                                        func::TimezoneTimestampTz(tz),
1077                                    ),
1078                                    expr: Box::new(expr2.take()),
1079                                },
1080                                Err(err) => MirScalarExpr::literal(
1081                                    Err(err),
1082                                    e.typ(column_types).scalar_type,
1083                                ),
1084                            }
1085                        } else if *func == BinaryFunc::ToCharTimestamp && expr2.is_literal() {
1086                            let format_str = expr2.as_literal_str().unwrap();
1087                            *e = MirScalarExpr::CallUnary {
1088                                func: UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
1089                                    format_string: format_str.to_string(),
1090                                    format: DateTimeFormat::compile(format_str),
1091                                }),
1092                                expr: Box::new(expr1.take()),
1093                            };
1094                        } else if *func == BinaryFunc::ToCharTimestampTz && expr2.is_literal() {
1095                            let format_str = expr2.as_literal_str().unwrap();
1096                            *e = MirScalarExpr::CallUnary {
1097                                func: UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
1098                                    format_string: format_str.to_string(),
1099                                    format: DateTimeFormat::compile(format_str),
1100                                }),
1101                                expr: Box::new(expr1.take()),
1102                            };
1103                        } else if matches!(*func, BinaryFunc::Eq | BinaryFunc::NotEq)
1104                            && expr2 < expr1
1105                        {
1106                            // Canonically order elements so that deduplication works better.
1107                            // Also, the below `Literal([c1, c2]) = record_create(e1, e2)` matching
1108                            // relies on this canonical ordering.
1109                            mem::swap(expr1, expr2);
1110                        } else if let (
1111                            BinaryFunc::Eq,
1112                            MirScalarExpr::Literal(
1113                                Ok(lit_row),
1114                                ColumnType {
1115                                    scalar_type:
1116                                        ScalarType::Record {
1117                                            fields: field_types,
1118                                            ..
1119                                        },
1120                                    ..
1121                                },
1122                            ),
1123                            MirScalarExpr::CallVariadic {
1124                                func: VariadicFunc::RecordCreate { .. },
1125                                exprs: rec_create_args,
1126                            },
1127                        ) = (&*func, &**expr1, &**expr2)
1128                        {
1129                            // Literal([c1, c2]) = record_create(e1, e2)
1130                            //  -->
1131                            // c1 = e1 AND c2 = e2
1132                            //
1133                            // (Records are represented as lists.)
1134                            //
1135                            // `MapFilterProject::literal_constraints` relies on this transform,
1136                            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
1137                            match lit_row.unpack_first() {
1138                                Datum::List(datum_list) => {
1139                                    *e = MirScalarExpr::CallVariadic {
1140                                        func: VariadicFunc::And,
1141                                        exprs: itertools::izip!(
1142                                            datum_list.iter(),
1143                                            field_types,
1144                                            rec_create_args
1145                                        )
1146                                        .map(|(d, (_, typ), a)| MirScalarExpr::CallBinary {
1147                                            func: BinaryFunc::Eq,
1148                                            expr1: Box::new(MirScalarExpr::Literal(
1149                                                Ok(Row::pack_slice(&[d])),
1150                                                typ.clone(),
1151                                            )),
1152                                            expr2: Box::new(a.clone()),
1153                                        })
1154                                        .collect(),
1155                                    };
1156                                }
1157                                _ => {}
1158                            }
1159                        } else if let (
1160                            BinaryFunc::Eq,
1161                            MirScalarExpr::CallVariadic {
1162                                func: VariadicFunc::RecordCreate { .. },
1163                                exprs: rec_create_args1,
1164                            },
1165                            MirScalarExpr::CallVariadic {
1166                                func: VariadicFunc::RecordCreate { .. },
1167                                exprs: rec_create_args2,
1168                            },
1169                        ) = (&*func, &**expr1, &**expr2)
1170                        {
1171                            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
1172                            //  -->
1173                            // a1 = b1 AND a2 = b2 AND ...
1174                            //
1175                            // This is similar to the previous reduction, but this one kicks in also
1176                            // when only some (or none) of the record fields are literals. This
1177                            // enables the discovery of literal constraints for those fields.
1178                            //
1179                            // Note that there is a similar decomposition in
1180                            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
1181                            // pipeline than the compilation of IN lists to `record_create`.
1182                            *e = MirScalarExpr::CallVariadic {
1183                                func: VariadicFunc::And,
1184                                exprs: rec_create_args1
1185                                    .into_iter()
1186                                    .zip(rec_create_args2)
1187                                    .map(|(a, b)| MirScalarExpr::CallBinary {
1188                                        func: BinaryFunc::Eq,
1189                                        expr1: Box::new(a.clone()),
1190                                        expr2: Box::new(b.clone()),
1191                                    })
1192                                    .collect(),
1193                            }
1194                        }
1195                    }
1196                    MirScalarExpr::CallVariadic { .. } => {
1197                        e.flatten_associative();
1198                        let (func, exprs) = match e {
1199                            MirScalarExpr::CallVariadic { func, exprs } => (func, exprs),
1200                            _ => unreachable!("`flatten_associative` shouldn't change node type"),
1201                        };
1202                        if *func == VariadicFunc::Coalesce {
1203                            // If all inputs are null, output is null. This check must
1204                            // be done before `exprs.retain...` because `e.typ` requires
1205                            // > 0 `exprs` remain.
1206                            if exprs.iter().all(|expr| expr.is_literal_null()) {
1207                                *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1208                                return;
1209                            }
1210
1211                            // Remove any null values if not all values are null.
1212                            exprs.retain(|e| !e.is_literal_null());
1213
1214                            // Find the first argument that is a literal or non-nullable
1215                            // column. All arguments after it get ignored, so throw them
1216                            // away. This intentionally throws away errors that can
1217                            // never happen.
1218                            if let Some(i) = exprs
1219                                .iter()
1220                                .position(|e| e.is_literal() || !e.typ(column_types).nullable)
1221                            {
1222                                exprs.truncate(i + 1);
1223                            }
1224
1225                            // Deduplicate arguments in cases like `coalesce(#0, #0)`.
1226                            let mut prior_exprs = BTreeSet::new();
1227                            exprs.retain(|e| prior_exprs.insert(e.clone()));
1228
1229                            if let Some(expr) = exprs.iter_mut().find(|e| e.is_literal_err()) {
1230                                // One of the remaining arguments is an error, so
1231                                // just replace the entire coalesce with that error.
1232                                *e = expr.take();
1233                            } else if exprs.len() == 1 {
1234                                // Only one argument, so the coalesce is a no-op.
1235                                *e = exprs[0].take();
1236                            }
1237                        } else if exprs.iter().all(|e| e.is_literal()) {
1238                            *e = eval(e);
1239                        } else if func.propagates_nulls()
1240                            && exprs.iter().any(|e| e.is_literal_null())
1241                        {
1242                            *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
1243                        } else if let Some(err) = exprs.iter().find_map(|e| e.as_literal_err()) {
1244                            *e = MirScalarExpr::literal(
1245                                Err(err.clone()),
1246                                e.typ(column_types).scalar_type,
1247                            );
1248                        } else if *func == VariadicFunc::RegexpMatch
1249                            && exprs[1].is_literal()
1250                            && exprs.get(2).map_or(true, |e| e.is_literal())
1251                        {
1252                            let needle = exprs[1].as_literal_str().unwrap();
1253                            let flags = match exprs.len() {
1254                                3 => exprs[2].as_literal_str().unwrap(),
1255                                _ => "",
1256                            };
1257                            *e = match func::build_regex(needle, flags) {
1258                                Ok(regex) => mem::take(exprs)
1259                                    .into_first()
1260                                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
1261                                Err(err) => MirScalarExpr::literal(
1262                                    Err(err),
1263                                    e.typ(column_types).scalar_type,
1264                                ),
1265                            };
1266                        } else if *func == VariadicFunc::RegexpReplace
1267                            && exprs[1].is_literal()
1268                            && exprs.get(3).map_or(true, |e| e.is_literal())
1269                        {
1270                            let pattern = exprs[1].as_literal_str().unwrap();
1271                            let flags = exprs
1272                                .get(3)
1273                                .map_or("", |expr| expr.as_literal_str().unwrap());
1274                            let (limit, flags) = regexp_replace_parse_flags(flags);
1275
1276                            // The behavior of `regexp_replace` is that if the data is `NULL`, the
1277                            // function returns `NULL`, independently of whether the pattern or
1278                            // flags are correct. We need to check for this case and introduce an
1279                            // if-then-else on the error path to only surface the error if the first
1280                            // input is non-NULL.
1281                            *e = match func::build_regex(pattern, &flags) {
1282                                Ok(regex) => {
1283                                    let mut exprs = mem::take(exprs);
1284                                    let replacement = exprs.swap_remove(2);
1285                                    let source = exprs.swap_remove(0);
1286                                    source.call_binary(
1287                                        replacement,
1288                                        BinaryFunc::RegexpReplace { regex, limit },
1289                                    )
1290                                }
1291                                Err(err) => {
1292                                    let mut exprs = mem::take(exprs);
1293                                    let source = exprs.swap_remove(0);
1294                                    let scalar_type = e.typ(column_types).scalar_type;
1295                                    // We need to return `NULL` on `NULL` input, and error otherwise.
1296                                    source.call_is_null().if_then_else(
1297                                        MirScalarExpr::literal_null(scalar_type.clone()),
1298                                        MirScalarExpr::literal(Err(err), scalar_type),
1299                                    )
1300                                }
1301                            };
1302                        } else if *func == VariadicFunc::RegexpSplitToArray
1303                            && exprs[1].is_literal()
1304                            && exprs.get(2).map_or(true, |e| e.is_literal())
1305                        {
1306                            let needle = exprs[1].as_literal_str().unwrap();
1307                            let flags = match exprs.len() {
1308                                3 => exprs[2].as_literal_str().unwrap(),
1309                                _ => "",
1310                            };
1311                            *e = match func::build_regex(needle, flags) {
1312                                Ok(regex) => mem::take(exprs).into_first().call_unary(
1313                                    UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(regex)),
1314                                ),
1315                                Err(err) => MirScalarExpr::literal(
1316                                    Err(err),
1317                                    e.typ(column_types).scalar_type,
1318                                ),
1319                            };
1320                        } else if *func == VariadicFunc::ListIndex && is_list_create_call(&exprs[0])
1321                        {
1322                            // We are looking for ListIndex(ListCreate, literal), and eliminate
1323                            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
1324                            let ind_exprs = exprs.split_off(1);
1325                            let top_list_create = exprs.swap_remove(0);
1326                            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
1327                        } else if *func == VariadicFunc::Or || *func == VariadicFunc::And {
1328                            // Note: It's important that we have called `flatten_associative` above.
1329                            e.undistribute_and_or();
1330                            e.reduce_and_canonicalize_and_or();
1331                        } else if let VariadicFunc::TimezoneTime = func {
1332                            if exprs[0].is_literal() && exprs[2].is_literal_ok() {
1333                                let tz = exprs[0].as_literal_str().unwrap();
1334                                *e = match parse_timezone(tz, TimezoneSpec::Posix) {
1335                                    Ok(tz) => MirScalarExpr::CallUnary {
1336                                        func: UnaryFunc::TimezoneTime(func::TimezoneTime {
1337                                            tz,
1338                                            wall_time: exprs[2]
1339                                                .as_literal()
1340                                                .unwrap()
1341                                                .unwrap()
1342                                                .unwrap_timestamptz()
1343                                                .naive_utc(),
1344                                        }),
1345                                        expr: Box::new(exprs[1].take()),
1346                                    },
1347                                    Err(err) => MirScalarExpr::literal(
1348                                        Err(err),
1349                                        e.typ(column_types).scalar_type,
1350                                    ),
1351                                }
1352                            }
1353                        }
1354                    }
1355                    MirScalarExpr::If { cond, then, els } => {
1356                        if let Some(literal) = cond.as_literal() {
1357                            match literal {
1358                                Ok(Datum::True) => *e = then.take(),
1359                                Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
1360                                Err(err) => {
1361                                    *e = MirScalarExpr::Literal(
1362                                        Err(err.clone()),
1363                                        then.typ(column_types)
1364                                            .union(&els.typ(column_types))
1365                                            .unwrap(),
1366                                    )
1367                                }
1368                                _ => unreachable!(),
1369                            }
1370                        } else if then == els {
1371                            *e = then.take();
1372                        } else if then.is_literal_ok()
1373                            && els.is_literal_ok()
1374                            && then.typ(column_types).scalar_type == ScalarType::Bool
1375                            && els.typ(column_types).scalar_type == ScalarType::Bool
1376                        {
1377                            match (then.as_literal(), els.as_literal()) {
1378                                // Note: NULLs from the condition should not be propagated to the result
1379                                // of the expression.
1380                                (Some(Ok(Datum::True)), _) => {
1381                                    // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
1382                                    // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
1383                                    *e = cond
1384                                        .clone()
1385                                        .call_is_null()
1386                                        .not()
1387                                        .and(cond.take())
1388                                        .or(els.take());
1389                                }
1390                                (Some(Ok(Datum::False)), _) => {
1391                                    // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
1392                                    // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
1393                                    *e = cond
1394                                        .clone()
1395                                        .not()
1396                                        .or(cond.take().call_is_null())
1397                                        .and(els.take());
1398                                }
1399                                (_, Some(Ok(Datum::True))) => {
1400                                    // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
1401                                    // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
1402                                    *e = cond
1403                                        .clone()
1404                                        .not()
1405                                        .or(cond.take().call_is_null())
1406                                        .or(then.take());
1407                                }
1408                                (_, Some(Ok(Datum::False))) => {
1409                                    // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
1410                                    // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
1411                                    *e = cond
1412                                        .clone()
1413                                        .call_is_null()
1414                                        .not()
1415                                        .and(cond.take())
1416                                        .and(then.take());
1417                                }
1418                                _ => {}
1419                            }
1420                        } else {
1421                            // Equivalent expression structure would allow us to push the `If` into the expression.
1422                            // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
1423                            //
1424                            // We have to also make sure that the expressions that will end up in
1425                            // the two `If` branches have unionable types. Otherwise, the `If` could
1426                            // not be typed by `typ`. An example where this could cause an issue is
1427                            // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
1428                            // range of input types. (In theory, we could still do the optimization
1429                            // in this case by inserting appropriate casts, but this corner case is
1430                            // not worth the complication for now.)
1431                            // See https://github.com/MaterializeInc/database-issues/issues/9182
1432                            match (&mut **then, &mut **els) {
1433                                (
1434                                    MirScalarExpr::CallUnary { func: f1, expr: e1 },
1435                                    MirScalarExpr::CallUnary { func: f2, expr: e2 },
1436                                ) if f1 == f2
1437                                    && e1
1438                                        .typ(column_types)
1439                                        .union(&e2.typ(column_types))
1440                                        .is_ok() =>
1441                                {
1442                                    *e = cond
1443                                        .take()
1444                                        .if_then_else(e1.take(), e2.take())
1445                                        .call_unary(f1.clone());
1446                                }
1447                                (
1448                                    MirScalarExpr::CallBinary {
1449                                        func: f1,
1450                                        expr1: e1a,
1451                                        expr2: e2a,
1452                                    },
1453                                    MirScalarExpr::CallBinary {
1454                                        func: f2,
1455                                        expr1: e1b,
1456                                        expr2: e2b,
1457                                    },
1458                                ) if f1 == f2
1459                                    && e1a == e1b
1460                                    && e2a
1461                                        .typ(column_types)
1462                                        .union(&e2b.typ(column_types))
1463                                        .is_ok() =>
1464                                {
1465                                    *e = e1a.take().call_binary(
1466                                        cond.take().if_then_else(e2a.take(), e2b.take()),
1467                                        f1.clone(),
1468                                    );
1469                                }
1470                                (
1471                                    MirScalarExpr::CallBinary {
1472                                        func: f1,
1473                                        expr1: e1a,
1474                                        expr2: e2a,
1475                                    },
1476                                    MirScalarExpr::CallBinary {
1477                                        func: f2,
1478                                        expr1: e1b,
1479                                        expr2: e2b,
1480                                    },
1481                                ) if f1 == f2
1482                                    && e2a == e2b
1483                                    && e1a
1484                                        .typ(column_types)
1485                                        .union(&e1b.typ(column_types))
1486                                        .is_ok() =>
1487                                {
1488                                    *e = cond
1489                                        .take()
1490                                        .if_then_else(e1a.take(), e1b.take())
1491                                        .call_binary(e2a.take(), f1.clone());
1492                                }
1493                                _ => {}
1494                            }
1495                        }
1496                    }
1497                },
1498            );
1499        }
1500
1501        /* #region `reduce_list_create_list_index_literal` and helper functions */
1502
1503        fn list_create_type(list_create: &MirScalarExpr) -> ScalarType {
1504            if let MirScalarExpr::CallVariadic {
1505                func: VariadicFunc::ListCreate { elem_type: typ },
1506                ..
1507            } = list_create
1508            {
1509                (*typ).clone()
1510            } else {
1511                unreachable!()
1512            }
1513        }
1514
1515        fn is_list_create_call(expr: &MirScalarExpr) -> bool {
1516            matches!(
1517                expr,
1518                MirScalarExpr::CallVariadic {
1519                    func: VariadicFunc::ListCreate { .. },
1520                    ..
1521                }
1522            )
1523        }
1524
1525        /// Partial-evaluates a list indexing with a literal directly after a list creation.
1526        ///
1527        /// Multi-dimensional lists are handled by a single call to this function, with multiple
1528        /// elements in index_exprs (of which not all need to be literals), and nested ListCreates
1529        /// in list_create_to_reduce.
1530        ///
1531        /// # Examples
1532        ///
1533        /// `LIST[f1,f2][2]` --> `f2`.
1534        ///
1535        /// A multi-dimensional list, with only some of the indexes being literals:
1536        /// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
1537        ///
1538        /// See more examples in list.slt.
1539        fn reduce_list_create_list_index_literal(
1540            mut list_create_to_reduce: MirScalarExpr,
1541            mut index_exprs: Vec<MirScalarExpr>,
1542        ) -> MirScalarExpr {
1543            // We iterate over the index_exprs and remove literals, but keep non-literals.
1544            // When we encounter a non-literal, we need to dig into the nested ListCreates:
1545            // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
1546            // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
1547            // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
1548            // that are at the current level (except those that disappeared due to
1549            // literals at earlier levels), index into them with the literal, and change each
1550            // element in `list_create_mut_refs` to the result.
1551            // We also record mut refs to all the earlier `element_type` references that we have
1552            // seen in ListCreate calls, because when we process a literal index, we need to remove
1553            // one layer of list type from all these earlier ListCreate `element_type`s.
1554            let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
1555            let mut earlier_list_create_types: Vec<&mut ScalarType> = vec![];
1556            let mut i = 0;
1557            while i < index_exprs.len()
1558                && list_create_mut_refs
1559                    .iter()
1560                    .all(|lc| is_list_create_call(lc))
1561            {
1562                if index_exprs[i].is_literal_ok() {
1563                    // We can remove this index.
1564                    let removed_index = index_exprs.remove(i);
1565                    let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
1566                        Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
1567                        _ => unreachable!(), // always an Int64, see plan_index_list
1568                    };
1569                    // For each list_create referenced by list_create_mut_refs, substitute it by its
1570                    // `index`th argument (or null).
1571                    for list_create in &mut list_create_mut_refs {
1572                        let list_create_args = match list_create {
1573                            MirScalarExpr::CallVariadic {
1574                                func: VariadicFunc::ListCreate { elem_type: _ },
1575                                exprs,
1576                            } => exprs,
1577                            _ => unreachable!(), // func cannot be anything else than a ListCreate
1578                        };
1579                        // ListIndex gives null on an out-of-bounds index
1580                        if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap()
1581                        {
1582                            let index: usize = index_i64.try_into().unwrap();
1583                            **list_create = list_create_args.swap_remove(index);
1584                        } else {
1585                            let typ = list_create_type(list_create);
1586                            **list_create = MirScalarExpr::literal_null(typ);
1587                        }
1588                    }
1589                    // Peel one layer off of each of the earlier element types.
1590                    for t in earlier_list_create_types.iter_mut() {
1591                        if let ScalarType::List {
1592                            element_type,
1593                            custom_id: _,
1594                        } = t
1595                        {
1596                            **t = *element_type.clone();
1597                            // These are not the same types anymore, so remove custom_ids all the
1598                            // way down.
1599                            let mut u = &mut **t;
1600                            while let ScalarType::List {
1601                                element_type,
1602                                custom_id,
1603                            } = u
1604                            {
1605                                *custom_id = None;
1606                                u = &mut **element_type;
1607                            }
1608                        } else {
1609                            unreachable!("already matched below");
1610                        }
1611                    }
1612                } else {
1613                    // We can't remove this index, so we can't reduce any of the ListCreates at this
1614                    // level. So we change list_create_mut_refs to refer to all the arguments of all
1615                    // the ListCreates currently referenced by list_create_mut_refs.
1616                    list_create_mut_refs = list_create_mut_refs
1617                        .into_iter()
1618                        .flat_map(|list_create| match list_create {
1619                            MirScalarExpr::CallVariadic {
1620                                func: VariadicFunc::ListCreate { elem_type },
1621                                exprs: list_create_args,
1622                            } => {
1623                                earlier_list_create_types.push(elem_type);
1624                                list_create_args
1625                            }
1626                            // func cannot be anything else than a ListCreate
1627                            _ => unreachable!(),
1628                        })
1629                        .collect();
1630                    i += 1; // next index_expr
1631                }
1632            }
1633            // If all list indexes have been evaluated, return the reduced expression.
1634            // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
1635            if index_exprs.is_empty() {
1636                assert_eq!(list_create_mut_refs.len(), 1);
1637                list_create_to_reduce
1638            } else {
1639                let mut exprs: Vec<MirScalarExpr> = vec![list_create_to_reduce];
1640                exprs.append(&mut index_exprs);
1641                MirScalarExpr::CallVariadic {
1642                    func: VariadicFunc::ListIndex,
1643                    exprs,
1644                }
1645            }
1646        }
1647
1648        /* #endregion */
1649    }
1650
1651    /// Decompose an IsNull expression into a disjunction of
1652    /// simpler expressions.
1653    ///
1654    /// Assumes that `self` is the expression inside of an IsNull.
1655    /// Returns `Some(expressions)` if the outer IsNull is to be
1656    /// replaced by some other expression. Note: if it returns
1657    /// None, it might still have mutated *self.
1658    fn decompose_is_null(&mut self) -> Option<MirScalarExpr> {
1659        // TODO: allow simplification of unmaterializable functions
1660
1661        match self {
1662            MirScalarExpr::CallUnary {
1663                func,
1664                expr: inner_expr,
1665            } => {
1666                if !func.introduces_nulls() {
1667                    if func.propagates_nulls() {
1668                        *self = inner_expr.take();
1669                        return self.decompose_is_null();
1670                    } else {
1671                        // Different from CallBinary and CallVariadic, because of determinism. See
1672                        // https://materializeinc.slack.com/archives/C01BE3RN82F/p1657644478517709
1673                        return Some(MirScalarExpr::literal_false());
1674                    }
1675                }
1676            }
1677            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
1678                // (<expr1> <op> <expr2>) IS NULL can often be simplified to
1679                // (<expr1> IS NULL) OR (<expr2> IS NULL).
1680                if func.propagates_nulls() && !func.introduces_nulls() {
1681                    let expr1 = expr1.take().call_is_null();
1682                    let expr2 = expr2.take().call_is_null();
1683                    return Some(expr1.or(expr2));
1684                }
1685            }
1686            MirScalarExpr::CallVariadic { func, exprs } => {
1687                if func.propagates_nulls() && !func.introduces_nulls() {
1688                    let exprs = exprs.into_iter().map(|e| e.take().call_is_null()).collect();
1689                    return Some(MirScalarExpr::CallVariadic {
1690                        func: VariadicFunc::Or,
1691                        exprs,
1692                    });
1693                }
1694            }
1695            _ => {}
1696        }
1697
1698        None
1699    }
1700
1701    /// Flattens a chain of calls to associative variadic functions
1702    /// (For example: ORs or ANDs)
1703    pub fn flatten_associative(&mut self) {
1704        match self {
1705            MirScalarExpr::CallVariadic {
1706                exprs: outer_operands,
1707                func: outer_func,
1708            } if outer_func.is_associative() => {
1709                *outer_operands = outer_operands
1710                    .into_iter()
1711                    .flat_map(|o| {
1712                        if let MirScalarExpr::CallVariadic {
1713                            exprs: inner_operands,
1714                            func: inner_func,
1715                        } = o
1716                        {
1717                            if *inner_func == *outer_func {
1718                                mem::take(inner_operands)
1719                            } else {
1720                                vec![o.take()]
1721                            }
1722                        } else {
1723                            vec![o.take()]
1724                        }
1725                    })
1726                    .collect();
1727            }
1728            _ => {}
1729        }
1730    }
1731
1732    /* #region AND/OR canonicalization and transformations  */
1733
1734    /// Canonicalizes AND/OR, and does some straightforward simplifications
1735    fn reduce_and_canonicalize_and_or(&mut self) {
1736        // We do this until fixed point, because after undistribute_and_or calls us, it relies on
1737        // the property that self is not an 1-arg AND/OR. Just one application of our loop body
1738        // can't ensure this, because the application itself might create a 1-arg AND/OR.
1739        let mut old_self = MirScalarExpr::column(0);
1740        while old_self != *self {
1741            old_self = self.clone();
1742            match self {
1743                MirScalarExpr::CallVariadic {
1744                    func: func @ (VariadicFunc::And | VariadicFunc::Or),
1745                    exprs,
1746                } => {
1747                    // Canonically order elements so that various deduplications work better,
1748                    // e.g., in undistribute_and_or.
1749                    // Also, extract_equal_or_both_null_inner depends on the args being sorted.
1750                    exprs.sort();
1751
1752                    // x AND/OR x --> x
1753                    exprs.dedup(); // this also needs the above sorting
1754
1755                    if exprs.len() == 1 {
1756                        // AND/OR of 1 argument evaluates to that argument
1757                        *self = exprs.swap_remove(0);
1758                    } else if exprs.len() == 0 {
1759                        // AND/OR of 0 arguments evaluates to true/false
1760                        *self = func.unit_of_and_or();
1761                    } else if exprs.iter().any(|e| *e == func.zero_of_and_or()) {
1762                        // short-circuiting
1763                        *self = func.zero_of_and_or();
1764                    } else {
1765                        // a AND true --> a
1766                        // a OR false --> a
1767                        exprs.retain(|e| *e != func.unit_of_and_or());
1768                    }
1769                }
1770                _ => {}
1771            }
1772        }
1773    }
1774
1775    /// Transforms !(a && b) into !a || !b, and !(a || b) into !a && !b
1776    fn demorgans(&mut self) {
1777        if let MirScalarExpr::CallUnary {
1778            expr: inner,
1779            func: UnaryFunc::Not(func::Not),
1780        } = self
1781        {
1782            inner.flatten_associative();
1783            match &mut **inner {
1784                MirScalarExpr::CallVariadic {
1785                    func: inner_func @ (VariadicFunc::And | VariadicFunc::Or),
1786                    exprs,
1787                } => {
1788                    *inner_func = inner_func.switch_and_or();
1789                    *exprs = exprs.into_iter().map(|e| e.take().not()).collect();
1790                    *self = (*inner).take(); // Removes the outer not
1791                }
1792                _ => {}
1793            }
1794        }
1795    }
1796
1797    /// AND/OR undistribution (factoring out) to apply at each `MirScalarExpr`.
1798    ///
1799    /// This method attempts to apply one of the [distribution laws][distributivity]
1800    /// (in a direction opposite to their name):
1801    /// ```text
1802    /// (a && b) || (a && c) --> a && (b || c)  // Undistribute-OR
1803    /// (a || b) && (a || c) --> a || (b && c)  // Undistribute-AND
1804    /// ```
1805    /// or one of their corresponding two [absorption law][absorption] special
1806    /// cases:
1807    /// ```text
1808    /// a || (a && c)  -->  a  // Absorb-OR
1809    /// a && (a || c)  -->  a  // Absorb-AND
1810    /// ```
1811    ///
1812    /// The method also works with more than 2 arguments at the top, e.g.
1813    /// ```text
1814    /// (a && b) || (a && c) || (a && d)  -->  a && (b || c || d)
1815    /// ```
1816    /// It can also factor out only a subset of the top arguments, e.g.
1817    /// ```text
1818    /// (a && b) || (a && c) || (d && e)  -->  (a && (b || c)) || (d && e)
1819    /// ```
1820    ///
1821    /// Note that sometimes there are two overlapping possibilities to factor
1822    /// out from, e.g.
1823    /// ```text
1824    /// (a && b) || (a && c) || (d && c)
1825    /// ```
1826    /// Here we can factor out `a` from the 1. and 2. terms, or we can
1827    /// factor out `c` from the 2. and 3. terms. One of these might lead to
1828    /// more/better undistribution opportunities later, but we just pick one
1829    /// locally, because recursively trying out all of them would lead to
1830    /// exponential run time.
1831    ///
1832    /// The local heuristic is that we prefer a candidate that leads to an
1833    /// absorption, or if there is no such one then we simply pick the first. In
1834    /// case of multiple absorption candidates, it doesn't matter which one we
1835    /// pick, because applying an absorption cannot adversely affect the
1836    /// possibility of applying other absorptions.
1837    ///
1838    /// # Assumption
1839    ///
1840    /// Assumes that nested chains of AND/OR applications are flattened (this
1841    /// can be enforced with [`Self::flatten_associative`]).
1842    ///
1843    /// # Examples
1844    ///
1845    /// Absorb-OR:
1846    /// ```text
1847    /// a || (a && c) || (a && d)
1848    /// -->
1849    /// a && (true || c || d)
1850    /// -->
1851    /// a && true
1852    /// -->
1853    /// a
1854    /// ```
1855    /// Here only the first step is performed by this method. The rest is done
1856    /// by [`Self::reduce_and_canonicalize_and_or`] called after us in
1857    /// `reduce()`.
1858    ///
1859    /// [distributivity]: https://en.wikipedia.org/wiki/Distributive_property
1860    /// [absorption]: https://en.wikipedia.org/wiki/Absorption_law
1861    fn undistribute_and_or(&mut self) {
1862        // It wouldn't be strictly necessary to wrap this fn in this loop, because `reduce()` calls
1863        // us in a loop anyway. However, `reduce()` tries to do many other things, so the loop here
1864        // improves performance when there are several undistributions to apply in sequence, which
1865        // can occur in `CanonicalizeMfp` when undoing the DNF.
1866        let mut old_self = MirScalarExpr::column(0);
1867        while old_self != *self {
1868            old_self = self.clone();
1869            self.reduce_and_canonicalize_and_or(); // We don't want to deal with 1-arg AND/OR at the top
1870            if let MirScalarExpr::CallVariadic {
1871                exprs: outer_operands,
1872                func: outer_func @ (VariadicFunc::Or | VariadicFunc::And),
1873            } = self
1874            {
1875                let inner_func = outer_func.switch_and_or();
1876
1877                // Make sure that each outer operand is a call to inner_func, by wrapping in a 1-arg
1878                // call if necessary.
1879                outer_operands.iter_mut().for_each(|o| {
1880                    if !matches!(o, MirScalarExpr::CallVariadic {func: f, ..} if *f == inner_func) {
1881                        *o = MirScalarExpr::CallVariadic {
1882                            func: inner_func.clone(),
1883                            exprs: vec![o.take()],
1884                        };
1885                    }
1886                });
1887
1888                let mut inner_operands_refs: Vec<&mut Vec<MirScalarExpr>> = outer_operands
1889                    .iter_mut()
1890                    .map(|o| match o {
1891                        MirScalarExpr::CallVariadic { func: f, exprs } if *f == inner_func => exprs,
1892                        _ => unreachable!(), // the wrapping made sure that we'll get a match
1893                    })
1894                    .collect();
1895
1896                // Find inner operands to undistribute, i.e., which are in _all_ of the outer operands.
1897                let mut intersection = inner_operands_refs
1898                    .iter()
1899                    .map(|v| (*v).clone())
1900                    .reduce(|ops1, ops2| ops1.into_iter().filter(|e| ops2.contains(e)).collect())
1901                    .unwrap();
1902                intersection.sort();
1903                intersection.dedup();
1904
1905                if !intersection.is_empty() {
1906                    // Factor out the intersection from all the top-level args.
1907
1908                    // Remove the intersection from each inner operand vector.
1909                    inner_operands_refs
1910                        .iter_mut()
1911                        .for_each(|ops| (**ops).retain(|o| !intersection.contains(o)));
1912
1913                    // Simplify terms that now have only 0 or 1 args due to removing the intersection.
1914                    outer_operands
1915                        .iter_mut()
1916                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1917
1918                    // Add the intersection at the beginning
1919                    *self = MirScalarExpr::CallVariadic {
1920                        func: inner_func,
1921                        exprs: intersection.into_iter().chain_one(self.clone()).collect(),
1922                    };
1923                } else {
1924                    // If the intersection was empty, that means that there is nothing we can factor out
1925                    // from _all_ the top-level args. However, we might still find something to factor
1926                    // out from a subset of the top-level args. To find such an opportunity, we look for
1927                    // duplicates across all inner args, e.g. if we have
1928                    // `(...) OR (... AND `a` AND ...) OR (...) OR (... AND `a` AND ...)`
1929                    // then we'll find that `a` occurs in more than one top-level arg, so
1930                    // `indexes_to_undistribute` will point us to the 2. and 4. top-level args.
1931
1932                    // Create (inner_operand, index) pairs, where the index is the position in
1933                    // outer_operands
1934                    let all_inner_operands = inner_operands_refs
1935                        .iter()
1936                        .enumerate()
1937                        .flat_map(|(i, inner_vec)| inner_vec.iter().map(move |a| ((*a).clone(), i)))
1938                        .sorted()
1939                        .collect_vec();
1940
1941                    // Find inner operand expressions that occur in more than one top-level arg.
1942                    // Each inner vector in `undistribution_opportunities` will belong to one such inner
1943                    // operand expression, and it is a set of indexes pointing to top-level args where
1944                    // that inner operand occurs.
1945                    let undistribution_opportunities = all_inner_operands
1946                        .iter()
1947                        .chunk_by(|(a, _i)| a)
1948                        .into_iter()
1949                        .map(|(_a, g)| g.map(|(_a, i)| *i).sorted().dedup().collect_vec())
1950                        .filter(|g| g.len() > 1)
1951                        .collect_vec();
1952
1953                    // Choose one of the inner vectors from `undistribution_opportunities`.
1954                    let indexes_to_undistribute = undistribution_opportunities
1955                        .iter()
1956                        // Let's prefer index sets that directly lead to an absorption.
1957                        .find(|index_set| {
1958                            index_set
1959                                .iter()
1960                                .any(|i| inner_operands_refs.get(*i).unwrap().len() == 1)
1961                        })
1962                        // If we didn't find any absorption, then any index set will do.
1963                        .or_else(|| undistribution_opportunities.first())
1964                        .cloned();
1965
1966                    // In any case, undo the 1-arg wrapping that we did at the beginning.
1967                    outer_operands
1968                        .iter_mut()
1969                        .for_each(|o| o.reduce_and_canonicalize_and_or());
1970
1971                    if let Some(indexes_to_undistribute) = indexes_to_undistribute {
1972                        // Found something to undistribute from a subset of the outer operands.
1973                        // We temporarily remove these from outer_operands, call ourselves on it, and
1974                        // then push back the result.
1975                        let mut undistribute_from = MirScalarExpr::CallVariadic {
1976                            func: outer_func.clone(),
1977                            exprs: swap_remove_multiple(outer_operands, indexes_to_undistribute),
1978                        };
1979                        // By construction, the recursive call is guaranteed to hit
1980                        // the `!intersection.is_empty()` branch.
1981                        undistribute_from.undistribute_and_or();
1982                        // Append the undistributed result to outer operands that were not included in
1983                        // indexes_to_undistribute.
1984                        outer_operands.push(undistribute_from);
1985                    }
1986                }
1987            }
1988        }
1989    }
1990
1991    /* #endregion */
1992
1993    /// Adds any columns that *must* be non-Null for `self` to be non-Null.
1994    pub fn non_null_requirements(&self, columns: &mut BTreeSet<usize>) {
1995        match self {
1996            MirScalarExpr::Column(col, _name) => {
1997                columns.insert(*col);
1998            }
1999            MirScalarExpr::Literal(..) => {}
2000            MirScalarExpr::CallUnmaterializable(_) => (),
2001            MirScalarExpr::CallUnary { func, expr } => {
2002                if func.propagates_nulls() {
2003                    expr.non_null_requirements(columns);
2004                }
2005            }
2006            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2007                if func.propagates_nulls() {
2008                    expr1.non_null_requirements(columns);
2009                    expr2.non_null_requirements(columns);
2010                }
2011            }
2012            MirScalarExpr::CallVariadic { func, exprs } => {
2013                if func.propagates_nulls() {
2014                    for expr in exprs {
2015                        expr.non_null_requirements(columns);
2016                    }
2017                }
2018            }
2019            MirScalarExpr::If { .. } => (),
2020        }
2021    }
2022
2023    pub fn typ(&self, column_types: &[ColumnType]) -> ColumnType {
2024        match self {
2025            MirScalarExpr::Column(i, _name) => column_types[*i].clone(),
2026            MirScalarExpr::Literal(_, typ) => typ.clone(),
2027            MirScalarExpr::CallUnmaterializable(func) => func.output_type(),
2028            MirScalarExpr::CallUnary { expr, func } => func.output_type(expr.typ(column_types)),
2029            MirScalarExpr::CallBinary { expr1, expr2, func } => {
2030                func.output_type(expr1.typ(column_types), expr2.typ(column_types))
2031            }
2032            MirScalarExpr::CallVariadic { exprs, func } => {
2033                func.output_type(exprs.iter().map(|e| e.typ(column_types)).collect())
2034            }
2035            MirScalarExpr::If { cond: _, then, els } => {
2036                let then_type = then.typ(column_types);
2037                let else_type = els.typ(column_types);
2038                then_type.union(&else_type).unwrap()
2039            }
2040        }
2041    }
2042
2043    pub fn eval<'a>(
2044        &'a self,
2045        datums: &[Datum<'a>],
2046        temp_storage: &'a RowArena,
2047    ) -> Result<Datum<'a>, EvalError> {
2048        match self {
2049            MirScalarExpr::Column(index, _name) => Ok(datums[*index].clone()),
2050            MirScalarExpr::Literal(res, _column_type) => match res {
2051                Ok(row) => Ok(row.unpack_first()),
2052                Err(e) => Err(e.clone()),
2053            },
2054            // Unmaterializable functions must be transformed away before
2055            // evaluation. Their purpose is as a placeholder for data that is
2056            // not known at plan time but can be inlined before runtime.
2057            MirScalarExpr::CallUnmaterializable(x) => Err(EvalError::Internal(
2058                format!("cannot evaluate unmaterializable function: {:?}", x).into(),
2059            )),
2060            MirScalarExpr::CallUnary { func, expr } => func.eval(datums, temp_storage, expr),
2061            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2062                func.eval(datums, temp_storage, expr1, expr2)
2063            }
2064            MirScalarExpr::CallVariadic { func, exprs } => func.eval(datums, temp_storage, exprs),
2065            MirScalarExpr::If { cond, then, els } => match cond.eval(datums, temp_storage)? {
2066                Datum::True => then.eval(datums, temp_storage),
2067                Datum::False | Datum::Null => els.eval(datums, temp_storage),
2068                d => Err(EvalError::Internal(
2069                    format!("if condition evaluated to non-boolean datum: {:?}", d).into(),
2070                )),
2071            },
2072        }
2073    }
2074
2075    /// True iff the expression contains
2076    /// `UnmaterializableFunc::MzNow`.
2077    pub fn contains_temporal(&self) -> bool {
2078        let mut contains = false;
2079        self.visit_pre(|e| {
2080            if let MirScalarExpr::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2081                contains = true;
2082            }
2083        });
2084        contains
2085    }
2086
2087    /// True iff the expression contains an `UnmaterializableFunc`.
2088    pub fn contains_unmaterializable(&self) -> bool {
2089        let mut contains = false;
2090        self.visit_pre(|e| {
2091            if let MirScalarExpr::CallUnmaterializable(_) = e {
2092                contains = true;
2093            }
2094        });
2095        contains
2096    }
2097
2098    /// True iff the expression contains an `UnmaterializableFunc` that is not in the `exceptions`
2099    /// list.
2100    pub fn contains_unmaterializable_except(&self, exceptions: &[UnmaterializableFunc]) -> bool {
2101        let mut contains = false;
2102        self.visit_pre(|e| match e {
2103            MirScalarExpr::CallUnmaterializable(f) if !exceptions.contains(f) => contains = true,
2104            _ => (),
2105        });
2106        contains
2107    }
2108
2109    /// True iff the expression contains a `Column`.
2110    pub fn contains_column(&self) -> bool {
2111        let mut contains = false;
2112        self.visit_pre(|e| {
2113            if let MirScalarExpr::Column(_col, _name) = e {
2114                contains = true;
2115            }
2116        });
2117        contains
2118    }
2119
2120    /// True iff the expression contains a `Dummy`.
2121    pub fn contains_dummy(&self) -> bool {
2122        let mut contains = false;
2123        self.visit_pre(|e| {
2124            if let MirScalarExpr::Literal(row, _) = e {
2125                if let Ok(row) = row {
2126                    contains |= row.iter().any(|d| d == Datum::Dummy);
2127                }
2128            }
2129        });
2130        contains
2131    }
2132
2133    /// The size of the expression as a tree.
2134    pub fn size(&self) -> usize {
2135        let mut size = 0;
2136        self.visit_pre(&mut |_: &MirScalarExpr| {
2137            size += 1;
2138        });
2139        size
2140    }
2141}
2142
2143impl MirScalarExpr {
2144    /// True iff evaluation could possibly error on non-error input `Datum`.
2145    pub fn could_error(&self) -> bool {
2146        match self {
2147            MirScalarExpr::Column(_col, _name) => false,
2148            MirScalarExpr::Literal(row, ..) => row.is_err(),
2149            MirScalarExpr::CallUnmaterializable(_) => true,
2150            MirScalarExpr::CallUnary { func, expr } => func.could_error() || expr.could_error(),
2151            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
2152                func.could_error() || expr1.could_error() || expr2.could_error()
2153            }
2154            MirScalarExpr::CallVariadic { func, exprs } => {
2155                func.could_error() || exprs.iter().any(|e| e.could_error())
2156            }
2157            MirScalarExpr::If { cond, then, els } => {
2158                cond.could_error() || then.could_error() || els.could_error()
2159            }
2160        }
2161    }
2162}
2163
2164impl VisitChildren<Self> for MirScalarExpr {
2165    fn visit_children<F>(&self, mut f: F)
2166    where
2167        F: FnMut(&Self),
2168    {
2169        use MirScalarExpr::*;
2170        match self {
2171            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2172            CallUnary { expr, .. } => {
2173                f(expr);
2174            }
2175            CallBinary { expr1, expr2, .. } => {
2176                f(expr1);
2177                f(expr2);
2178            }
2179            CallVariadic { exprs, .. } => {
2180                for expr in exprs {
2181                    f(expr);
2182                }
2183            }
2184            If { cond, then, els } => {
2185                f(cond);
2186                f(then);
2187                f(els);
2188            }
2189        }
2190    }
2191
2192    fn visit_mut_children<F>(&mut self, mut f: F)
2193    where
2194        F: FnMut(&mut Self),
2195    {
2196        use MirScalarExpr::*;
2197        match self {
2198            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2199            CallUnary { expr, .. } => {
2200                f(expr);
2201            }
2202            CallBinary { expr1, expr2, .. } => {
2203                f(expr1);
2204                f(expr2);
2205            }
2206            CallVariadic { exprs, .. } => {
2207                for expr in exprs {
2208                    f(expr);
2209                }
2210            }
2211            If { cond, then, els } => {
2212                f(cond);
2213                f(then);
2214                f(els);
2215            }
2216        }
2217    }
2218
2219    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2220    where
2221        F: FnMut(&Self) -> Result<(), E>,
2222        E: From<RecursionLimitError>,
2223    {
2224        use MirScalarExpr::*;
2225        match self {
2226            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2227            CallUnary { expr, .. } => {
2228                f(expr)?;
2229            }
2230            CallBinary { expr1, expr2, .. } => {
2231                f(expr1)?;
2232                f(expr2)?;
2233            }
2234            CallVariadic { exprs, .. } => {
2235                for expr in exprs {
2236                    f(expr)?;
2237                }
2238            }
2239            If { cond, then, els } => {
2240                f(cond)?;
2241                f(then)?;
2242                f(els)?;
2243            }
2244        }
2245        Ok(())
2246    }
2247
2248    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2249    where
2250        F: FnMut(&mut Self) -> Result<(), E>,
2251        E: From<RecursionLimitError>,
2252    {
2253        use MirScalarExpr::*;
2254        match self {
2255            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2256            CallUnary { expr, .. } => {
2257                f(expr)?;
2258            }
2259            CallBinary { expr1, expr2, .. } => {
2260                f(expr1)?;
2261                f(expr2)?;
2262            }
2263            CallVariadic { exprs, .. } => {
2264                for expr in exprs {
2265                    f(expr)?;
2266                }
2267            }
2268            If { cond, then, els } => {
2269                f(cond)?;
2270                f(then)?;
2271                f(els)?;
2272            }
2273        }
2274        Ok(())
2275    }
2276}
2277
2278impl MirScalarExpr {
2279    /// Iterates through references to child expressions.
2280    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2281        let mut first = None;
2282        let mut second = None;
2283        let mut third = None;
2284        let mut variadic = None;
2285
2286        use MirScalarExpr::*;
2287        match self {
2288            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2289            CallUnary { expr, .. } => {
2290                first = Some(&**expr);
2291            }
2292            CallBinary { expr1, expr2, .. } => {
2293                first = Some(&**expr1);
2294                second = Some(&**expr2);
2295            }
2296            CallVariadic { exprs, .. } => {
2297                variadic = Some(exprs);
2298            }
2299            If { cond, then, els } => {
2300                first = Some(&**cond);
2301                second = Some(&**then);
2302                third = Some(&**els);
2303            }
2304        }
2305
2306        first
2307            .into_iter()
2308            .chain(second)
2309            .chain(third)
2310            .chain(variadic.into_iter().flatten())
2311    }
2312
2313    /// Iterates through mutable references to child expressions.
2314    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2315        let mut first = None;
2316        let mut second = None;
2317        let mut third = None;
2318        let mut variadic = None;
2319
2320        use MirScalarExpr::*;
2321        match self {
2322            Column(_, _) | Literal(_, _) | CallUnmaterializable(_) => (),
2323            CallUnary { expr, .. } => {
2324                first = Some(&mut **expr);
2325            }
2326            CallBinary { expr1, expr2, .. } => {
2327                first = Some(&mut **expr1);
2328                second = Some(&mut **expr2);
2329            }
2330            CallVariadic { exprs, .. } => {
2331                variadic = Some(exprs);
2332            }
2333            If { cond, then, els } => {
2334                first = Some(&mut **cond);
2335                second = Some(&mut **then);
2336                third = Some(&mut **els);
2337            }
2338        }
2339
2340        first
2341            .into_iter()
2342            .chain(second)
2343            .chain(third)
2344            .chain(variadic.into_iter().flatten())
2345    }
2346
2347    /// Visits all subexpressions in DFS preorder.
2348    pub fn visit_pre<F>(&self, mut f: F)
2349    where
2350        F: FnMut(&Self),
2351    {
2352        let mut worklist = vec![self];
2353        while let Some(e) = worklist.pop() {
2354            f(e);
2355            worklist.extend(e.children().rev());
2356        }
2357    }
2358
2359    /// Iterative pre-order visitor.
2360    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2361        let mut worklist = vec![self];
2362        while let Some(expr) = worklist.pop() {
2363            f(expr);
2364            worklist.extend(expr.children_mut().rev());
2365        }
2366    }
2367}
2368
2369/// Filter characteristics that are used for ordering join inputs.
2370/// This can be created for a `Vec<MirScalarExpr>`, which represents an AND of predicates.
2371///
2372/// The fields are ordered based on heuristic assumptions about their typical selectivity, so that
2373/// Ord gives the right ordering for join inputs. Bigger is better, i.e., will tend to come earlier
2374/// than other inputs.
2375#[derive(
2376    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
2377)]
2378pub struct FilterCharacteristics {
2379    // `<expr> = <literal>` appears in the filter.
2380    // Excludes cases where NOT appears anywhere above the literal equality.
2381    literal_equality: bool,
2382    // (Assuming a random string of lower-case characters, `LIKE 'a%'` has a selectivity of 1/26.)
2383    like: bool,
2384    is_null: bool,
2385    // Number of Vec elements that involve inequality predicates. (A BETWEEN is represented as two
2386    // inequality predicates.)
2387    // Excludes cases where NOT appears around the literal inequality.
2388    // Note that for inequality predicates, some databases assume 1/3 selectivity in the absence of
2389    // concrete statistics.
2390    literal_inequality: usize,
2391    /// Any filter, except ones involving `IS NOT NULL`, because those are too common.
2392    /// Can be true by itself, or any other field being true can also make this true.
2393    /// `NOT LIKE` is only in this category.
2394    /// `!=` is only in this category.
2395    /// `NOT (a = b)` is turned into `!=` by `reduce` before us!
2396    any_filter: bool,
2397}
2398
2399impl BitOrAssign for FilterCharacteristics {
2400    fn bitor_assign(&mut self, rhs: Self) {
2401        self.literal_equality |= rhs.literal_equality;
2402        self.like |= rhs.like;
2403        self.is_null |= rhs.is_null;
2404        self.literal_inequality += rhs.literal_inequality;
2405        self.any_filter |= rhs.any_filter;
2406    }
2407}
2408
2409impl FilterCharacteristics {
2410    pub fn none() -> FilterCharacteristics {
2411        FilterCharacteristics {
2412            literal_equality: false,
2413            like: false,
2414            is_null: false,
2415            literal_inequality: 0,
2416            any_filter: false,
2417        }
2418    }
2419
2420    pub fn explain(&self) -> String {
2421        let mut e = "".to_owned();
2422        if self.literal_equality {
2423            e.push_str("e");
2424        }
2425        if self.like {
2426            e.push_str("l");
2427        }
2428        if self.is_null {
2429            e.push_str("n");
2430        }
2431        for _ in 0..self.literal_inequality {
2432            e.push_str("i");
2433        }
2434        if self.any_filter {
2435            e.push_str("f");
2436        }
2437        e
2438    }
2439
2440    pub fn filter_characteristics(
2441        filters: &Vec<MirScalarExpr>,
2442    ) -> Result<FilterCharacteristics, RecursionLimitError> {
2443        let mut literal_equality = false;
2444        let mut like = false;
2445        let mut is_null = false;
2446        let mut literal_inequality = 0;
2447        let mut any_filter = false;
2448        filters.iter().try_for_each(|f| {
2449            let mut literal_inequality_in_current_filter = false;
2450            let mut is_not_null_in_current_filter = false;
2451            f.visit_pre_with_context(
2452                false,
2453                &mut |not_in_parent_chain, expr| {
2454                    not_in_parent_chain
2455                        || matches!(
2456                            expr,
2457                            MirScalarExpr::CallUnary {
2458                                func: UnaryFunc::Not(func::Not),
2459                                ..
2460                            }
2461                        )
2462                },
2463                &mut |not_in_parent_chain, expr| {
2464                    if !not_in_parent_chain {
2465                        if expr.any_expr_eq_literal().is_some() {
2466                            literal_equality = true;
2467                        }
2468                        if expr.any_expr_ineq_literal() {
2469                            literal_inequality_in_current_filter = true;
2470                        }
2471                        if matches!(
2472                            expr,
2473                            MirScalarExpr::CallUnary {
2474                                func: UnaryFunc::IsLikeMatch(_),
2475                                ..
2476                            }
2477                        ) {
2478                            like = true;
2479                        }
2480                    };
2481                    if matches!(
2482                        expr,
2483                        MirScalarExpr::CallUnary {
2484                            func: UnaryFunc::IsNull(crate::func::IsNull),
2485                            ..
2486                        }
2487                    ) {
2488                        if *not_in_parent_chain {
2489                            is_not_null_in_current_filter = true;
2490                        } else {
2491                            is_null = true;
2492                        }
2493                    }
2494                },
2495            )?;
2496            if literal_inequality_in_current_filter {
2497                literal_inequality += 1;
2498            }
2499            if !is_not_null_in_current_filter {
2500                // We want to ignore `IS NOT NULL` for `any_filter`.
2501                any_filter = true;
2502            }
2503            Ok(())
2504        })?;
2505        Ok(FilterCharacteristics {
2506            literal_equality,
2507            like,
2508            is_null,
2509            literal_inequality,
2510            any_filter,
2511        })
2512    }
2513
2514    pub fn add_literal_equality(&mut self) {
2515        self.literal_equality = true;
2516    }
2517
2518    pub fn worst_case_scaling_factor(&self) -> f64 {
2519        let mut factor = 1.0;
2520
2521        if self.literal_equality {
2522            factor *= 0.1;
2523        }
2524
2525        if self.is_null {
2526            factor *= 0.1;
2527        }
2528
2529        if self.literal_inequality >= 2 {
2530            factor *= 0.25;
2531        } else if self.literal_inequality == 1 {
2532            factor *= 0.33;
2533        }
2534
2535        // catch various negated filters, treat them pessimistically
2536        if !(self.literal_equality || self.is_null || self.literal_inequality > 0)
2537            && self.any_filter
2538        {
2539            factor *= 0.9;
2540        }
2541
2542        factor
2543    }
2544}
2545
2546#[derive(
2547    Arbitrary,
2548    Ord,
2549    PartialOrd,
2550    Copy,
2551    Clone,
2552    Debug,
2553    Eq,
2554    PartialEq,
2555    Serialize,
2556    Deserialize,
2557    Hash,
2558    MzReflect,
2559)]
2560pub enum DomainLimit {
2561    None,
2562    Inclusive(i64),
2563    Exclusive(i64),
2564}
2565
2566impl RustType<ProtoDomainLimit> for DomainLimit {
2567    fn into_proto(&self) -> ProtoDomainLimit {
2568        use proto_domain_limit::Kind::*;
2569        let kind = match self {
2570            DomainLimit::None => None(()),
2571            DomainLimit::Inclusive(v) => Inclusive(*v),
2572            DomainLimit::Exclusive(v) => Exclusive(*v),
2573        };
2574        ProtoDomainLimit { kind: Some(kind) }
2575    }
2576
2577    fn from_proto(proto: ProtoDomainLimit) -> Result<Self, TryFromProtoError> {
2578        use proto_domain_limit::Kind::*;
2579        if let Some(kind) = proto.kind {
2580            match kind {
2581                None(()) => Ok(DomainLimit::None),
2582                Inclusive(v) => Ok(DomainLimit::Inclusive(v)),
2583                Exclusive(v) => Ok(DomainLimit::Exclusive(v)),
2584            }
2585        } else {
2586            Err(TryFromProtoError::missing_field("ProtoDomainLimit::kind"))
2587        }
2588    }
2589}
2590
2591#[derive(
2592    Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect,
2593)]
2594pub enum EvalError {
2595    CharacterNotValidForEncoding(i32),
2596    CharacterTooLargeForEncoding(i32),
2597    DateBinOutOfRange(Box<str>),
2598    DivisionByZero,
2599    Unsupported {
2600        feature: Box<str>,
2601        discussion_no: Option<usize>,
2602    },
2603    FloatOverflow,
2604    FloatUnderflow,
2605    NumericFieldOverflow,
2606    Float32OutOfRange(Box<str>),
2607    Float64OutOfRange(Box<str>),
2608    Int16OutOfRange(Box<str>),
2609    Int32OutOfRange(Box<str>),
2610    Int64OutOfRange(Box<str>),
2611    UInt16OutOfRange(Box<str>),
2612    UInt32OutOfRange(Box<str>),
2613    UInt64OutOfRange(Box<str>),
2614    MzTimestampOutOfRange(Box<str>),
2615    MzTimestampStepOverflow,
2616    OidOutOfRange(Box<str>),
2617    IntervalOutOfRange(Box<str>),
2618    TimestampCannotBeNan,
2619    TimestampOutOfRange,
2620    DateOutOfRange,
2621    CharOutOfRange,
2622    IndexOutOfRange {
2623        provided: i32,
2624        // The last valid index position, i.e. `v.len() - 1`
2625        valid_end: i32,
2626    },
2627    InvalidBase64Equals,
2628    InvalidBase64Symbol(char),
2629    InvalidBase64EndSequence,
2630    InvalidTimezone(Box<str>),
2631    InvalidTimezoneInterval,
2632    InvalidTimezoneConversion,
2633    InvalidIanaTimezoneId(Box<str>),
2634    InvalidLayer {
2635        max_layer: usize,
2636        val: i64,
2637    },
2638    InvalidArray(InvalidArrayError),
2639    InvalidEncodingName(Box<str>),
2640    InvalidHashAlgorithm(Box<str>),
2641    InvalidByteSequence {
2642        byte_sequence: Box<str>,
2643        encoding_name: Box<str>,
2644    },
2645    InvalidJsonbCast {
2646        from: Box<str>,
2647        to: Box<str>,
2648    },
2649    InvalidRegex(Box<str>),
2650    InvalidRegexFlag(char),
2651    InvalidParameterValue(Box<str>),
2652    InvalidDatePart(Box<str>),
2653    KeyCannotBeNull,
2654    NegSqrt,
2655    NegLimit,
2656    NullCharacterNotPermitted,
2657    UnknownUnits(Box<str>),
2658    UnsupportedUnits(Box<str>, Box<str>),
2659    UnterminatedLikeEscapeSequence,
2660    Parse(ParseError),
2661    ParseHex(ParseHexError),
2662    Internal(Box<str>),
2663    InfinityOutOfDomain(Box<str>),
2664    NegativeOutOfDomain(Box<str>),
2665    ZeroOutOfDomain(Box<str>),
2666    OutOfDomain(DomainLimit, DomainLimit, Box<str>),
2667    ComplexOutOfRange(Box<str>),
2668    MultipleRowsFromSubquery,
2669    Undefined(Box<str>),
2670    LikePatternTooLong,
2671    LikeEscapeTooLong,
2672    StringValueTooLong {
2673        target_type: Box<str>,
2674        length: usize,
2675    },
2676    MultidimensionalArrayRemovalNotSupported,
2677    IncompatibleArrayDimensions {
2678        dims: Option<(usize, usize)>,
2679    },
2680    TypeFromOid(Box<str>),
2681    InvalidRange(InvalidRangeError),
2682    InvalidRoleId(Box<str>),
2683    InvalidPrivileges(Box<str>),
2684    LetRecLimitExceeded(Box<str>),
2685    MultiDimensionalArraySearch,
2686    MustNotBeNull(Box<str>),
2687    InvalidIdentifier {
2688        ident: Box<str>,
2689        detail: Option<Box<str>>,
2690    },
2691    ArrayFillWrongArraySubscripts,
2692    // TODO: propagate this check more widely throughout the expr crate
2693    MaxArraySizeExceeded(usize),
2694    DateDiffOverflow {
2695        unit: Box<str>,
2696        a: Box<str>,
2697        b: Box<str>,
2698    },
2699    // The error for ErrorIfNull; this should not be used in other contexts as a generic error
2700    // printer.
2701    IfNullError(Box<str>),
2702    LengthTooLarge,
2703    AclArrayNullElement,
2704    MzAclArrayNullElement,
2705    PrettyError(Box<str>),
2706}
2707
2708impl fmt::Display for EvalError {
2709    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2710        match self {
2711            EvalError::CharacterNotValidForEncoding(v) => {
2712                write!(f, "requested character not valid for encoding: {v}")
2713            }
2714            EvalError::CharacterTooLargeForEncoding(v) => {
2715                write!(f, "requested character too large for encoding: {v}")
2716            }
2717            EvalError::DateBinOutOfRange(message) => f.write_str(message),
2718            EvalError::DivisionByZero => f.write_str("division by zero"),
2719            EvalError::Unsupported {
2720                feature,
2721                discussion_no,
2722            } => {
2723                write!(f, "{} not yet supported", feature)?;
2724                if let Some(discussion_no) = discussion_no {
2725                    write!(
2726                        f,
2727                        ", see https://github.com/MaterializeInc/materialize/discussions/{} for more details",
2728                        discussion_no
2729                    )?;
2730                }
2731                Ok(())
2732            }
2733            EvalError::FloatOverflow => f.write_str("value out of range: overflow"),
2734            EvalError::FloatUnderflow => f.write_str("value out of range: underflow"),
2735            EvalError::NumericFieldOverflow => f.write_str("numeric field overflow"),
2736            EvalError::Float32OutOfRange(val) => write!(f, "{} real out of range", val.quoted()),
2737            EvalError::Float64OutOfRange(val) => {
2738                write!(f, "{} double precision out of range", val.quoted())
2739            }
2740            EvalError::Int16OutOfRange(val) => write!(f, "{} smallint out of range", val.quoted()),
2741            EvalError::Int32OutOfRange(val) => write!(f, "{} integer out of range", val.quoted()),
2742            EvalError::Int64OutOfRange(val) => write!(f, "{} bigint out of range", val.quoted()),
2743            EvalError::UInt16OutOfRange(val) => write!(f, "{} uint2 out of range", val.quoted()),
2744            EvalError::UInt32OutOfRange(val) => write!(f, "{} uint4 out of range", val.quoted()),
2745            EvalError::UInt64OutOfRange(val) => write!(f, "{} uint8 out of range", val.quoted()),
2746            EvalError::MzTimestampOutOfRange(val) => {
2747                write!(f, "{} mz_timestamp out of range", val.quoted())
2748            }
2749            EvalError::MzTimestampStepOverflow => f.write_str("step mz_timestamp overflow"),
2750            EvalError::OidOutOfRange(val) => write!(f, "{} OID out of range", val.quoted()),
2751            EvalError::IntervalOutOfRange(val) => {
2752                write!(f, "{} interval out of range", val.quoted())
2753            }
2754            EvalError::TimestampCannotBeNan => f.write_str("timestamp cannot be NaN"),
2755            EvalError::TimestampOutOfRange => f.write_str("timestamp out of range"),
2756            EvalError::DateOutOfRange => f.write_str("date out of range"),
2757            EvalError::CharOutOfRange => f.write_str("\"char\" out of range"),
2758            EvalError::IndexOutOfRange {
2759                provided,
2760                valid_end,
2761            } => write!(f, "index {provided} out of valid range, 0..{valid_end}",),
2762            EvalError::InvalidBase64Equals => {
2763                f.write_str("unexpected \"=\" while decoding base64 sequence")
2764            }
2765            EvalError::InvalidBase64Symbol(c) => write!(
2766                f,
2767                "invalid symbol \"{}\" found while decoding base64 sequence",
2768                c.escape_default()
2769            ),
2770            EvalError::InvalidBase64EndSequence => f.write_str("invalid base64 end sequence"),
2771            EvalError::InvalidJsonbCast { from, to } => {
2772                write!(f, "cannot cast jsonb {} to type {}", from, to)
2773            }
2774            EvalError::InvalidTimezone(tz) => write!(f, "invalid time zone '{}'", tz),
2775            EvalError::InvalidTimezoneInterval => {
2776                f.write_str("timezone interval must not contain months or years")
2777            }
2778            EvalError::InvalidTimezoneConversion => f.write_str("invalid timezone conversion"),
2779            EvalError::InvalidIanaTimezoneId(tz) => {
2780                write!(f, "invalid IANA Time Zone Database identifier: '{}'", tz)
2781            }
2782            EvalError::InvalidLayer { max_layer, val } => write!(
2783                f,
2784                "invalid layer: {}; must use value within [1, {}]",
2785                val, max_layer
2786            ),
2787            EvalError::InvalidArray(e) => e.fmt(f),
2788            EvalError::InvalidEncodingName(name) => write!(f, "invalid encoding name '{}'", name),
2789            EvalError::InvalidHashAlgorithm(alg) => write!(f, "invalid hash algorithm '{}'", alg),
2790            EvalError::InvalidByteSequence {
2791                byte_sequence,
2792                encoding_name,
2793            } => write!(
2794                f,
2795                "invalid byte sequence '{}' for encoding '{}'",
2796                byte_sequence, encoding_name
2797            ),
2798            EvalError::InvalidDatePart(part) => write!(f, "invalid datepart {}", part.quoted()),
2799            EvalError::KeyCannotBeNull => f.write_str("key cannot be null"),
2800            EvalError::NegSqrt => f.write_str("cannot take square root of a negative number"),
2801            EvalError::NegLimit => f.write_str("LIMIT must not be negative"),
2802            EvalError::NullCharacterNotPermitted => f.write_str("null character not permitted"),
2803            EvalError::InvalidRegex(e) => write!(f, "invalid regular expression: {}", e),
2804            EvalError::InvalidRegexFlag(c) => write!(f, "invalid regular expression flag: {}", c),
2805            EvalError::InvalidParameterValue(s) => f.write_str(s),
2806            EvalError::UnknownUnits(units) => write!(f, "unit '{}' not recognized", units),
2807            EvalError::UnsupportedUnits(units, typ) => {
2808                write!(f, "unit '{}' not supported for type {}", units, typ)
2809            }
2810            EvalError::UnterminatedLikeEscapeSequence => {
2811                f.write_str("unterminated escape sequence in LIKE")
2812            }
2813            EvalError::Parse(e) => e.fmt(f),
2814            EvalError::PrettyError(e) => e.fmt(f),
2815            EvalError::ParseHex(e) => e.fmt(f),
2816            EvalError::Internal(s) => write!(f, "internal error: {}", s),
2817            EvalError::InfinityOutOfDomain(s) => {
2818                write!(f, "function {} is only defined for finite arguments", s)
2819            }
2820            EvalError::NegativeOutOfDomain(s) => {
2821                write!(f, "function {} is not defined for negative numbers", s)
2822            }
2823            EvalError::ZeroOutOfDomain(s) => {
2824                write!(f, "function {} is not defined for zero", s)
2825            }
2826            EvalError::OutOfDomain(lower, upper, s) => {
2827                use DomainLimit::*;
2828                write!(f, "function {s} is defined for numbers ")?;
2829                match (lower, upper) {
2830                    (Inclusive(n), None) => write!(f, "greater than or equal to {n}"),
2831                    (Exclusive(n), None) => write!(f, "greater than {n}"),
2832                    (None, Inclusive(n)) => write!(f, "less than or equal to {n}"),
2833                    (None, Exclusive(n)) => write!(f, "less than {n}"),
2834                    (Inclusive(lo), Inclusive(hi)) => write!(f, "between {lo} and {hi} inclusive"),
2835                    (Exclusive(lo), Exclusive(hi)) => write!(f, "between {lo} and {hi} exclusive"),
2836                    (Inclusive(lo), Exclusive(hi)) => {
2837                        write!(f, "between {lo} inclusive and {hi} exclusive")
2838                    }
2839                    (Exclusive(lo), Inclusive(hi)) => {
2840                        write!(f, "between {lo} exclusive and {hi} inclusive")
2841                    }
2842                    (None, None) => panic!("invalid domain error"),
2843                }
2844            }
2845            EvalError::ComplexOutOfRange(s) => {
2846                write!(f, "function {} cannot return complex numbers", s)
2847            }
2848            EvalError::MultipleRowsFromSubquery => {
2849                write!(f, "more than one record produced in subquery")
2850            }
2851            EvalError::Undefined(s) => {
2852                write!(f, "{} is undefined", s)
2853            }
2854            EvalError::LikePatternTooLong => {
2855                write!(f, "LIKE pattern exceeds maximum length")
2856            }
2857            EvalError::LikeEscapeTooLong => {
2858                write!(f, "invalid escape string")
2859            }
2860            EvalError::StringValueTooLong {
2861                target_type,
2862                length,
2863            } => {
2864                write!(f, "value too long for type {}({})", target_type, length)
2865            }
2866            EvalError::MultidimensionalArrayRemovalNotSupported => {
2867                write!(
2868                    f,
2869                    "removing elements from multidimensional arrays is not supported"
2870                )
2871            }
2872            EvalError::IncompatibleArrayDimensions { dims: _ } => {
2873                write!(f, "cannot concatenate incompatible arrays")
2874            }
2875            EvalError::TypeFromOid(msg) => write!(f, "{msg}"),
2876            EvalError::InvalidRange(e) => e.fmt(f),
2877            EvalError::InvalidRoleId(msg) => write!(f, "{msg}"),
2878            EvalError::InvalidPrivileges(privilege) => {
2879                write!(f, "unrecognized privilege type: {privilege}")
2880            }
2881            EvalError::LetRecLimitExceeded(max_iters) => {
2882                write!(
2883                    f,
2884                    "Recursive query exceeded the recursion limit {}. (Use RETURN AT RECURSION LIMIT to not error, but return the current state as the final result when reaching the limit.)",
2885                    max_iters
2886                )
2887            }
2888            EvalError::MultiDimensionalArraySearch => write!(
2889                f,
2890                "searching for elements in multidimensional arrays is not supported"
2891            ),
2892            EvalError::MustNotBeNull(v) => write!(f, "{v} must not be null"),
2893            EvalError::InvalidIdentifier { ident, .. } => {
2894                write!(f, "string is not a valid identifier: {}", ident.quoted())
2895            }
2896            EvalError::ArrayFillWrongArraySubscripts => {
2897                f.write_str("wrong number of array subscripts")
2898            }
2899            EvalError::MaxArraySizeExceeded(max_size) => {
2900                write!(
2901                    f,
2902                    "array size exceeds the maximum allowed ({max_size} bytes)"
2903                )
2904            }
2905            EvalError::DateDiffOverflow { unit, a, b } => {
2906                write!(f, "datediff overflow, {unit} of {a}, {b}")
2907            }
2908            EvalError::IfNullError(s) => f.write_str(s),
2909            EvalError::LengthTooLarge => write!(f, "requested length too large"),
2910            EvalError::AclArrayNullElement => write!(f, "ACL arrays must not contain null values"),
2911            EvalError::MzAclArrayNullElement => {
2912                write!(f, "MZ_ACL arrays must not contain null values")
2913            }
2914        }
2915    }
2916}
2917
2918impl EvalError {
2919    pub fn detail(&self) -> Option<String> {
2920        match self {
2921            EvalError::IncompatibleArrayDimensions { dims: None } => Some(
2922                "Arrays with differing dimensions are not compatible for concatenation.".into(),
2923            ),
2924            EvalError::IncompatibleArrayDimensions {
2925                dims: Some((a_dims, b_dims)),
2926            } => Some(format!(
2927                "Arrays of {} and {} dimensions are not compatible for concatenation.",
2928                a_dims, b_dims
2929            )),
2930            EvalError::InvalidIdentifier { detail, .. } => detail.as_deref().map(Into::into),
2931            EvalError::ArrayFillWrongArraySubscripts => {
2932                Some("Low bound array has different size than dimensions array.".into())
2933            }
2934            _ => None,
2935        }
2936    }
2937
2938    pub fn hint(&self) -> Option<String> {
2939        match self {
2940            EvalError::InvalidBase64EndSequence => Some(
2941                "Input data is missing padding, is truncated, or is otherwise corrupted.".into(),
2942            ),
2943            EvalError::LikeEscapeTooLong => {
2944                Some("Escape string must be empty or one character.".into())
2945            }
2946            EvalError::MzTimestampOutOfRange(_) => Some(
2947                "Integer, numeric, and text casts to mz_timestamp must be in the form of whole \
2948                milliseconds since the Unix epoch. Values with fractional parts cannot be \
2949                converted to mz_timestamp."
2950                    .into(),
2951            ),
2952            _ => None,
2953        }
2954    }
2955}
2956
2957impl std::error::Error for EvalError {}
2958
2959impl From<ParseError> for EvalError {
2960    fn from(e: ParseError) -> EvalError {
2961        EvalError::Parse(e)
2962    }
2963}
2964
2965impl From<ParseHexError> for EvalError {
2966    fn from(e: ParseHexError) -> EvalError {
2967        EvalError::ParseHex(e)
2968    }
2969}
2970
2971impl From<InvalidArrayError> for EvalError {
2972    fn from(e: InvalidArrayError) -> EvalError {
2973        EvalError::InvalidArray(e)
2974    }
2975}
2976
2977impl From<regex::Error> for EvalError {
2978    fn from(e: regex::Error) -> EvalError {
2979        EvalError::InvalidRegex(e.to_string().into())
2980    }
2981}
2982
2983impl From<TypeFromOidError> for EvalError {
2984    fn from(e: TypeFromOidError) -> EvalError {
2985        EvalError::TypeFromOid(e.to_string().into())
2986    }
2987}
2988
2989impl From<DateError> for EvalError {
2990    fn from(e: DateError) -> EvalError {
2991        match e {
2992            DateError::OutOfRange => EvalError::DateOutOfRange,
2993        }
2994    }
2995}
2996
2997impl From<TimestampError> for EvalError {
2998    fn from(e: TimestampError) -> EvalError {
2999        match e {
3000            TimestampError::OutOfRange => EvalError::TimestampOutOfRange,
3001        }
3002    }
3003}
3004
3005impl From<InvalidRangeError> for EvalError {
3006    fn from(e: InvalidRangeError) -> EvalError {
3007        EvalError::InvalidRange(e)
3008    }
3009}
3010
3011impl RustType<ProtoEvalError> for EvalError {
3012    fn into_proto(&self) -> ProtoEvalError {
3013        use proto_eval_error::Kind::*;
3014        use proto_eval_error::*;
3015        let kind = match self {
3016            EvalError::CharacterNotValidForEncoding(v) => CharacterNotValidForEncoding(*v),
3017            EvalError::CharacterTooLargeForEncoding(v) => CharacterTooLargeForEncoding(*v),
3018            EvalError::DateBinOutOfRange(v) => DateBinOutOfRange(v.into_proto()),
3019            EvalError::DivisionByZero => DivisionByZero(()),
3020            EvalError::Unsupported {
3021                feature,
3022                discussion_no,
3023            } => Unsupported(ProtoUnsupported {
3024                feature: feature.into_proto(),
3025                discussion_no: discussion_no.into_proto(),
3026            }),
3027            EvalError::FloatOverflow => FloatOverflow(()),
3028            EvalError::FloatUnderflow => FloatUnderflow(()),
3029            EvalError::NumericFieldOverflow => NumericFieldOverflow(()),
3030            EvalError::Float32OutOfRange(val) => Float32OutOfRange(ProtoValueOutOfRange {
3031                value: val.to_string(),
3032            }),
3033            EvalError::Float64OutOfRange(val) => Float64OutOfRange(ProtoValueOutOfRange {
3034                value: val.to_string(),
3035            }),
3036            EvalError::Int16OutOfRange(val) => Int16OutOfRange(ProtoValueOutOfRange {
3037                value: val.to_string(),
3038            }),
3039            EvalError::Int32OutOfRange(val) => Int32OutOfRange(ProtoValueOutOfRange {
3040                value: val.to_string(),
3041            }),
3042            EvalError::Int64OutOfRange(val) => Int64OutOfRange(ProtoValueOutOfRange {
3043                value: val.to_string(),
3044            }),
3045            EvalError::UInt16OutOfRange(val) => Uint16OutOfRange(ProtoValueOutOfRange {
3046                value: val.to_string(),
3047            }),
3048            EvalError::UInt32OutOfRange(val) => Uint32OutOfRange(ProtoValueOutOfRange {
3049                value: val.to_string(),
3050            }),
3051            EvalError::UInt64OutOfRange(val) => Uint64OutOfRange(ProtoValueOutOfRange {
3052                value: val.to_string(),
3053            }),
3054            EvalError::MzTimestampOutOfRange(val) => MzTimestampOutOfRange(ProtoValueOutOfRange {
3055                value: val.to_string(),
3056            }),
3057            EvalError::MzTimestampStepOverflow => MzTimestampStepOverflow(()),
3058            EvalError::OidOutOfRange(val) => OidOutOfRange(ProtoValueOutOfRange {
3059                value: val.to_string(),
3060            }),
3061            EvalError::IntervalOutOfRange(val) => IntervalOutOfRange(ProtoValueOutOfRange {
3062                value: val.to_string(),
3063            }),
3064            EvalError::TimestampCannotBeNan => TimestampCannotBeNan(()),
3065            EvalError::TimestampOutOfRange => TimestampOutOfRange(()),
3066            EvalError::DateOutOfRange => DateOutOfRange(()),
3067            EvalError::CharOutOfRange => CharOutOfRange(()),
3068            EvalError::IndexOutOfRange {
3069                provided,
3070                valid_end,
3071            } => IndexOutOfRange(ProtoIndexOutOfRange {
3072                provided: *provided,
3073                valid_end: *valid_end,
3074            }),
3075            EvalError::InvalidBase64Equals => InvalidBase64Equals(()),
3076            EvalError::InvalidBase64Symbol(sym) => InvalidBase64Symbol(sym.into_proto()),
3077            EvalError::InvalidBase64EndSequence => InvalidBase64EndSequence(()),
3078            EvalError::InvalidTimezone(tz) => InvalidTimezone(tz.into_proto()),
3079            EvalError::InvalidTimezoneInterval => InvalidTimezoneInterval(()),
3080            EvalError::InvalidTimezoneConversion => InvalidTimezoneConversion(()),
3081            EvalError::InvalidLayer { max_layer, val } => InvalidLayer(ProtoInvalidLayer {
3082                max_layer: max_layer.into_proto(),
3083                val: *val,
3084            }),
3085            EvalError::InvalidArray(error) => InvalidArray(error.into_proto()),
3086            EvalError::InvalidEncodingName(v) => InvalidEncodingName(v.into_proto()),
3087            EvalError::InvalidHashAlgorithm(v) => InvalidHashAlgorithm(v.into_proto()),
3088            EvalError::InvalidByteSequence {
3089                byte_sequence,
3090                encoding_name,
3091            } => InvalidByteSequence(ProtoInvalidByteSequence {
3092                byte_sequence: byte_sequence.into_proto(),
3093                encoding_name: encoding_name.into_proto(),
3094            }),
3095            EvalError::InvalidJsonbCast { from, to } => InvalidJsonbCast(ProtoInvalidJsonbCast {
3096                from: from.into_proto(),
3097                to: to.into_proto(),
3098            }),
3099            EvalError::InvalidRegex(v) => InvalidRegex(v.into_proto()),
3100            EvalError::InvalidRegexFlag(v) => InvalidRegexFlag(v.into_proto()),
3101            EvalError::InvalidParameterValue(v) => InvalidParameterValue(v.into_proto()),
3102            EvalError::InvalidDatePart(part) => InvalidDatePart(part.into_proto()),
3103            EvalError::KeyCannotBeNull => KeyCannotBeNull(()),
3104            EvalError::NegSqrt => NegSqrt(()),
3105            EvalError::NegLimit => NegLimit(()),
3106            EvalError::NullCharacterNotPermitted => NullCharacterNotPermitted(()),
3107            EvalError::UnknownUnits(v) => UnknownUnits(v.into_proto()),
3108            EvalError::UnsupportedUnits(units, typ) => UnsupportedUnits(ProtoUnsupportedUnits {
3109                units: units.into_proto(),
3110                typ: typ.into_proto(),
3111            }),
3112            EvalError::UnterminatedLikeEscapeSequence => UnterminatedLikeEscapeSequence(()),
3113            EvalError::Parse(error) => Parse(error.into_proto()),
3114            EvalError::PrettyError(error) => PrettyError(error.into_proto()),
3115            EvalError::ParseHex(error) => ParseHex(error.into_proto()),
3116            EvalError::Internal(v) => Internal(v.into_proto()),
3117            EvalError::InfinityOutOfDomain(v) => InfinityOutOfDomain(v.into_proto()),
3118            EvalError::NegativeOutOfDomain(v) => NegativeOutOfDomain(v.into_proto()),
3119            EvalError::ZeroOutOfDomain(v) => ZeroOutOfDomain(v.into_proto()),
3120            EvalError::OutOfDomain(lower, upper, id) => OutOfDomain(ProtoOutOfDomain {
3121                lower: Some(lower.into_proto()),
3122                upper: Some(upper.into_proto()),
3123                id: id.into_proto(),
3124            }),
3125            EvalError::ComplexOutOfRange(v) => ComplexOutOfRange(v.into_proto()),
3126            EvalError::MultipleRowsFromSubquery => MultipleRowsFromSubquery(()),
3127            EvalError::Undefined(v) => Undefined(v.into_proto()),
3128            EvalError::LikePatternTooLong => LikePatternTooLong(()),
3129            EvalError::LikeEscapeTooLong => LikeEscapeTooLong(()),
3130            EvalError::StringValueTooLong {
3131                target_type,
3132                length,
3133            } => StringValueTooLong(ProtoStringValueTooLong {
3134                target_type: target_type.into_proto(),
3135                length: length.into_proto(),
3136            }),
3137            EvalError::MultidimensionalArrayRemovalNotSupported => {
3138                MultidimensionalArrayRemovalNotSupported(())
3139            }
3140            EvalError::IncompatibleArrayDimensions { dims } => {
3141                IncompatibleArrayDimensions(ProtoIncompatibleArrayDimensions {
3142                    dims: dims.into_proto(),
3143                })
3144            }
3145            EvalError::TypeFromOid(v) => TypeFromOid(v.into_proto()),
3146            EvalError::InvalidRange(error) => InvalidRange(error.into_proto()),
3147            EvalError::InvalidRoleId(v) => InvalidRoleId(v.into_proto()),
3148            EvalError::InvalidPrivileges(v) => InvalidPrivileges(v.into_proto()),
3149            EvalError::LetRecLimitExceeded(v) => WmrRecursionLimitExceeded(v.into_proto()),
3150            EvalError::MultiDimensionalArraySearch => MultiDimensionalArraySearch(()),
3151            EvalError::MustNotBeNull(v) => MustNotBeNull(v.into_proto()),
3152            EvalError::InvalidIdentifier { ident, detail } => {
3153                InvalidIdentifier(ProtoInvalidIdentifier {
3154                    ident: ident.into_proto(),
3155                    detail: detail.into_proto(),
3156                })
3157            }
3158            EvalError::ArrayFillWrongArraySubscripts => ArrayFillWrongArraySubscripts(()),
3159            EvalError::MaxArraySizeExceeded(max_size) => {
3160                MaxArraySizeExceeded(u64::cast_from(*max_size))
3161            }
3162            EvalError::DateDiffOverflow { unit, a, b } => DateDiffOverflow(ProtoDateDiffOverflow {
3163                unit: unit.into_proto(),
3164                a: a.into_proto(),
3165                b: b.into_proto(),
3166            }),
3167            EvalError::IfNullError(s) => IfNullError(s.into_proto()),
3168            EvalError::LengthTooLarge => LengthTooLarge(()),
3169            EvalError::AclArrayNullElement => AclArrayNullElement(()),
3170            EvalError::MzAclArrayNullElement => MzAclArrayNullElement(()),
3171            EvalError::InvalidIanaTimezoneId(s) => InvalidIanaTimezoneId(s.into_proto()),
3172        };
3173        ProtoEvalError { kind: Some(kind) }
3174    }
3175
3176    fn from_proto(proto: ProtoEvalError) -> Result<Self, TryFromProtoError> {
3177        use proto_eval_error::Kind::*;
3178        match proto.kind {
3179            Some(kind) => match kind {
3180                CharacterNotValidForEncoding(v) => Ok(EvalError::CharacterNotValidForEncoding(v)),
3181                CharacterTooLargeForEncoding(v) => Ok(EvalError::CharacterTooLargeForEncoding(v)),
3182                DateBinOutOfRange(v) => Ok(EvalError::DateBinOutOfRange(v.into())),
3183                DivisionByZero(()) => Ok(EvalError::DivisionByZero),
3184                Unsupported(v) => Ok(EvalError::Unsupported {
3185                    feature: v.feature.into(),
3186                    discussion_no: v.discussion_no.into_rust()?,
3187                }),
3188                FloatOverflow(()) => Ok(EvalError::FloatOverflow),
3189                FloatUnderflow(()) => Ok(EvalError::FloatUnderflow),
3190                NumericFieldOverflow(()) => Ok(EvalError::NumericFieldOverflow),
3191                Float32OutOfRange(val) => Ok(EvalError::Float32OutOfRange(val.value.into())),
3192                Float64OutOfRange(val) => Ok(EvalError::Float64OutOfRange(val.value.into())),
3193                Int16OutOfRange(val) => Ok(EvalError::Int16OutOfRange(val.value.into())),
3194                Int32OutOfRange(val) => Ok(EvalError::Int32OutOfRange(val.value.into())),
3195                Int64OutOfRange(val) => Ok(EvalError::Int64OutOfRange(val.value.into())),
3196                Uint16OutOfRange(val) => Ok(EvalError::UInt16OutOfRange(val.value.into())),
3197                Uint32OutOfRange(val) => Ok(EvalError::UInt32OutOfRange(val.value.into())),
3198                Uint64OutOfRange(val) => Ok(EvalError::UInt64OutOfRange(val.value.into())),
3199                MzTimestampOutOfRange(val) => {
3200                    Ok(EvalError::MzTimestampOutOfRange(val.value.into()))
3201                }
3202                MzTimestampStepOverflow(()) => Ok(EvalError::MzTimestampStepOverflow),
3203                OidOutOfRange(val) => Ok(EvalError::OidOutOfRange(val.value.into())),
3204                IntervalOutOfRange(val) => Ok(EvalError::IntervalOutOfRange(val.value.into())),
3205                TimestampCannotBeNan(()) => Ok(EvalError::TimestampCannotBeNan),
3206                TimestampOutOfRange(()) => Ok(EvalError::TimestampOutOfRange),
3207                DateOutOfRange(()) => Ok(EvalError::DateOutOfRange),
3208                CharOutOfRange(()) => Ok(EvalError::CharOutOfRange),
3209                IndexOutOfRange(v) => Ok(EvalError::IndexOutOfRange {
3210                    provided: v.provided,
3211                    valid_end: v.valid_end,
3212                }),
3213                InvalidBase64Equals(()) => Ok(EvalError::InvalidBase64Equals),
3214                InvalidBase64Symbol(v) => char::from_proto(v).map(EvalError::InvalidBase64Symbol),
3215                InvalidBase64EndSequence(()) => Ok(EvalError::InvalidBase64EndSequence),
3216                InvalidTimezone(v) => Ok(EvalError::InvalidTimezone(v.into())),
3217                InvalidTimezoneInterval(()) => Ok(EvalError::InvalidTimezoneInterval),
3218                InvalidTimezoneConversion(()) => Ok(EvalError::InvalidTimezoneConversion),
3219                InvalidLayer(v) => Ok(EvalError::InvalidLayer {
3220                    max_layer: usize::from_proto(v.max_layer)?,
3221                    val: v.val,
3222                }),
3223                InvalidArray(error) => Ok(EvalError::InvalidArray(error.into_rust()?)),
3224                InvalidEncodingName(v) => Ok(EvalError::InvalidEncodingName(v.into())),
3225                InvalidHashAlgorithm(v) => Ok(EvalError::InvalidHashAlgorithm(v.into())),
3226                InvalidByteSequence(v) => Ok(EvalError::InvalidByteSequence {
3227                    byte_sequence: v.byte_sequence.into(),
3228                    encoding_name: v.encoding_name.into(),
3229                }),
3230                InvalidJsonbCast(v) => Ok(EvalError::InvalidJsonbCast {
3231                    from: v.from.into(),
3232                    to: v.to.into(),
3233                }),
3234                InvalidRegex(v) => Ok(EvalError::InvalidRegex(v.into())),
3235                InvalidRegexFlag(v) => Ok(EvalError::InvalidRegexFlag(char::from_proto(v)?)),
3236                InvalidParameterValue(v) => Ok(EvalError::InvalidParameterValue(v.into())),
3237                InvalidDatePart(part) => Ok(EvalError::InvalidDatePart(part.into())),
3238                KeyCannotBeNull(()) => Ok(EvalError::KeyCannotBeNull),
3239                NegSqrt(()) => Ok(EvalError::NegSqrt),
3240                NegLimit(()) => Ok(EvalError::NegLimit),
3241                NullCharacterNotPermitted(()) => Ok(EvalError::NullCharacterNotPermitted),
3242                UnknownUnits(v) => Ok(EvalError::UnknownUnits(v.into())),
3243                UnsupportedUnits(v) => {
3244                    Ok(EvalError::UnsupportedUnits(v.units.into(), v.typ.into()))
3245                }
3246                UnterminatedLikeEscapeSequence(()) => Ok(EvalError::UnterminatedLikeEscapeSequence),
3247                Parse(error) => Ok(EvalError::Parse(error.into_rust()?)),
3248                ParseHex(error) => Ok(EvalError::ParseHex(error.into_rust()?)),
3249                Internal(v) => Ok(EvalError::Internal(v.into())),
3250                InfinityOutOfDomain(v) => Ok(EvalError::InfinityOutOfDomain(v.into())),
3251                NegativeOutOfDomain(v) => Ok(EvalError::NegativeOutOfDomain(v.into())),
3252                ZeroOutOfDomain(v) => Ok(EvalError::ZeroOutOfDomain(v.into())),
3253                OutOfDomain(v) => Ok(EvalError::OutOfDomain(
3254                    v.lower.into_rust_if_some("ProtoDomainLimit::lower")?,
3255                    v.upper.into_rust_if_some("ProtoDomainLimit::upper")?,
3256                    v.id.into(),
3257                )),
3258                ComplexOutOfRange(v) => Ok(EvalError::ComplexOutOfRange(v.into())),
3259                MultipleRowsFromSubquery(()) => Ok(EvalError::MultipleRowsFromSubquery),
3260                Undefined(v) => Ok(EvalError::Undefined(v.into())),
3261                LikePatternTooLong(()) => Ok(EvalError::LikePatternTooLong),
3262                LikeEscapeTooLong(()) => Ok(EvalError::LikeEscapeTooLong),
3263                StringValueTooLong(v) => Ok(EvalError::StringValueTooLong {
3264                    target_type: v.target_type.into(),
3265                    length: usize::from_proto(v.length)?,
3266                }),
3267                MultidimensionalArrayRemovalNotSupported(()) => {
3268                    Ok(EvalError::MultidimensionalArrayRemovalNotSupported)
3269                }
3270                IncompatibleArrayDimensions(v) => Ok(EvalError::IncompatibleArrayDimensions {
3271                    dims: v.dims.into_rust()?,
3272                }),
3273                TypeFromOid(v) => Ok(EvalError::TypeFromOid(v.into())),
3274                InvalidRange(e) => Ok(EvalError::InvalidRange(e.into_rust()?)),
3275                InvalidRoleId(v) => Ok(EvalError::InvalidRoleId(v.into())),
3276                InvalidPrivileges(v) => Ok(EvalError::InvalidPrivileges(v.into())),
3277                WmrRecursionLimitExceeded(v) => Ok(EvalError::LetRecLimitExceeded(v.into())),
3278                MultiDimensionalArraySearch(()) => Ok(EvalError::MultiDimensionalArraySearch),
3279                MustNotBeNull(v) => Ok(EvalError::MustNotBeNull(v.into())),
3280                InvalidIdentifier(v) => Ok(EvalError::InvalidIdentifier {
3281                    ident: v.ident.into(),
3282                    detail: v.detail.into_rust()?,
3283                }),
3284                ArrayFillWrongArraySubscripts(()) => Ok(EvalError::ArrayFillWrongArraySubscripts),
3285                MaxArraySizeExceeded(max_size) => {
3286                    Ok(EvalError::MaxArraySizeExceeded(usize::cast_from(max_size)))
3287                }
3288                DateDiffOverflow(v) => Ok(EvalError::DateDiffOverflow {
3289                    unit: v.unit.into(),
3290                    a: v.a.into(),
3291                    b: v.b.into(),
3292                }),
3293                IfNullError(v) => Ok(EvalError::IfNullError(v.into())),
3294                LengthTooLarge(()) => Ok(EvalError::LengthTooLarge),
3295                AclArrayNullElement(()) => Ok(EvalError::AclArrayNullElement),
3296                MzAclArrayNullElement(()) => Ok(EvalError::MzAclArrayNullElement),
3297                InvalidIanaTimezoneId(s) => Ok(EvalError::InvalidIanaTimezoneId(s.into())),
3298                PrettyError(s) => Ok(EvalError::PrettyError(s.into())),
3299            },
3300            None => Err(TryFromProtoError::missing_field("ProtoEvalError::kind")),
3301        }
3302    }
3303}
3304
3305impl RustType<ProtoDims> for (usize, usize) {
3306    fn into_proto(&self) -> ProtoDims {
3307        ProtoDims {
3308            f0: self.0.into_proto(),
3309            f1: self.1.into_proto(),
3310        }
3311    }
3312
3313    fn from_proto(proto: ProtoDims) -> Result<Self, TryFromProtoError> {
3314        Ok((proto.f0.into_rust()?, proto.f1.into_rust()?))
3315    }
3316}
3317
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;

    use super::*;

    /// Runs `MirScalarExpr::reduce` over a table of `COALESCE` calls and
    /// compares each reduced expression against the expected one.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_reduce() {
        // Columns 0 and 1 are nullable; column 2 is not.
        let relation_type = vec![
            ScalarType::Int64.nullable(true),
            ScalarType::Int64.nullable(true),
            ScalarType::Int64.nullable(false),
        ];

        // Shorthand constructors for the expressions used in the table below.
        let col = MirScalarExpr::column;
        let err = |e| MirScalarExpr::literal(Err(e), ScalarType::Int64);
        let lit = |i| MirScalarExpr::literal_ok(Datum::Int64(i), ScalarType::Int64);
        let null = || MirScalarExpr::literal_null(ScalarType::Int64);
        let coalesce = |exprs: Vec<MirScalarExpr>| MirScalarExpr::CallVariadic {
            func: VariadicFunc::Coalesce,
            exprs,
        };

        struct TestCase {
            input: MirScalarExpr,
            output: MirScalarExpr,
        }

        let test_cases = vec![
            // A single literal argument: expect just that literal.
            TestCase {
                input: coalesce(vec![lit(1)]),
                output: lit(1),
            },
            // Two literals: expect the first one.
            TestCase {
                input: coalesce(vec![lit(1), lit(2)]),
                output: lit(1),
            },
            // Null literals around a non-null literal: expect the non-null literal.
            TestCase {
                input: coalesce(vec![null(), lit(2), null()]),
                output: lit(2),
            },
            // Null literals are dropped and nothing is kept after the first
            // non-null literal.
            TestCase {
                input: coalesce(vec![null(), col(0), null(), col(1), lit(2), lit(3)]),
                output: coalesce(vec![col(0), col(1), lit(2)]),
            },
            // Nothing is kept after the non-nullable column 2.
            TestCase {
                input: coalesce(vec![col(0), col(2), col(1)]),
                output: coalesce(vec![col(0), col(2)]),
            },
            // An error literal after a non-null literal: expect the literal.
            TestCase {
                input: coalesce(vec![lit(1), err(EvalError::DivisionByZero)]),
                output: lit(1),
            },
            // An error literal after a nullable column: expect the error.
            TestCase {
                input: coalesce(vec![col(0), err(EvalError::DivisionByZero)]),
                output: err(EvalError::DivisionByZero),
            },
            // Two error literals after a null literal: expect the first error.
            TestCase {
                input: coalesce(vec![
                    null(),
                    err(EvalError::DivisionByZero),
                    err(EvalError::NumericFieldOverflow),
                ]),
                output: err(EvalError::DivisionByZero),
            },
        ];

        for TestCase { input, output } in test_cases {
            let mut actual = input.clone();
            actual.reduce(&relation_type);
            assert!(
                actual == output,
                "input: {}\nactual: {}\nexpected: {}",
                input,
                actual,
                output
            );
        }
    }

    // Arbitrary `MirScalarExpr` values must survive a protobuf round trip unchanged.
    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
        fn mir_scalar_expr_protobuf_roundtrip(expect in any::<MirScalarExpr>()) {
            let actual = protobuf_roundtrip::<_, ProtoMirScalarExpr>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Arbitrary `DomainLimit` values must survive a protobuf round trip unchanged.
    proptest! {
        #[mz_ore::test]
        fn domain_limit_protobuf_roundtrip(expect in any::<DomainLimit>()) {
            let actual = protobuf_roundtrip::<_, ProtoDomainLimit>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    // Arbitrary `EvalError` values must survive a protobuf round trip unchanged.
    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // too slow
        fn eval_error_protobuf_roundtrip(expect in any::<EvalError>()) {
            let actual = protobuf_roundtrip::<_, ProtoEvalError>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }
}