Skip to main content

mz_expr/scalar/reduce/
binary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallBinary` nodes.
11
12use std::mem;
13
14use itertools::Itertools;
15use mz_pgtz::timezone::{Timezone, TimezoneSpec};
16use mz_repr::adt::datetime::DateTimeUnits;
17use mz_repr::adt::interval::Interval;
18use mz_repr::adt::regex::Regex;
19use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena};
20
21use crate::scalar::func::format::DateTimeFormat;
22use crate::scalar::func::variadic::And;
23use crate::scalar::func::{self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone};
24use crate::scalar::like_pattern;
25use crate::{Eval, EvalError, MirScalarExpr};
26
27pub(super) fn reduce_call_binary(
28    e: &mut MirScalarExpr,
29    column_types: &[ReprColumnType],
30    temp_storage: &RowArena,
31) {
32    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
33        unreachable!()
34    };
35
36    // Fold/propagate literal-shaped operands first; precompiles below assume
37    // these have already fired.
38    if expr1.is_literal() && expr2.is_literal() {
39        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
40        return;
41    }
42    if (expr1.is_literal_null() || expr2.is_literal_null()) && func.propagates_nulls() {
43        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
44        return;
45    }
46    if let Some(err) = expr1.as_literal_err() {
47        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
48        return;
49    }
50    if let Some(err) = expr2.as_literal_err() {
51        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
52        return;
53    }
54
55    // Calls where a literal operand makes the call the identity function on
56    // the other operand reduce to that operand.
57    if reduce_call_binary_identity(e) {
58        return;
59    }
60    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
61        unreachable!()
62    };
63
64    // Per-function dispatch. Each precompile fires only if its literal-shaped
65    // argument is present; otherwise the call falls through unchanged.
66    match func {
67        BinaryFunc::IsLikeMatchCaseInsensitive(_) if expr2.is_literal() => {
68            // We can at least precompile the regex.
69            precompile_is_like(e, column_types, true);
70        }
71        BinaryFunc::IsLikeMatchCaseSensitive(_) if expr2.is_literal() => {
72            // We can at least precompile the regex.
73            precompile_is_like(e, column_types, false);
74        }
75        BinaryFunc::IsRegexpMatchCaseSensitive(_) | BinaryFunc::IsRegexpMatchCaseInsensitive(_) => {
76            let case_insensitive = matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
77            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
78                *e = match Regex::new(row.unpack_first().unwrap_str(), case_insensitive) {
79                    Ok(regex) => expr1
80                        .take()
81                        .call_unary(UnaryFunc::IsRegexpMatch(func::IsRegexpMatch(regex))),
82                    Err(err) => {
83                        MirScalarExpr::literal(Err(err.into()), e.typ(column_types).scalar_type)
84                    }
85                };
86            }
87        }
88        BinaryFunc::ExtractInterval(_) if expr1.is_literal() => {
89            precompile_date_units(e, column_types, |u| {
90                UnaryFunc::ExtractInterval(func::ExtractInterval(u))
91            })
92        }
93        BinaryFunc::ExtractTime(_) if expr1.is_literal() => {
94            precompile_date_units(e, column_types, |u| {
95                UnaryFunc::ExtractTime(func::ExtractTime(u))
96            })
97        }
98        BinaryFunc::ExtractTimestamp(_) if expr1.is_literal() => {
99            precompile_date_units(e, column_types, |u| {
100                UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(u))
101            })
102        }
103        BinaryFunc::ExtractTimestampTz(_) if expr1.is_literal() => {
104            precompile_date_units(e, column_types, |u| {
105                UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(u))
106            })
107        }
108        BinaryFunc::ExtractDate(_) if expr1.is_literal() => {
109            precompile_date_units(e, column_types, |u| {
110                UnaryFunc::ExtractDate(func::ExtractDate(u))
111            })
112        }
113        BinaryFunc::DatePartInterval(_) if expr1.is_literal() => {
114            precompile_date_units(e, column_types, |u| {
115                UnaryFunc::DatePartInterval(func::DatePartInterval(u))
116            })
117        }
118        BinaryFunc::DatePartTime(_) if expr1.is_literal() => {
119            precompile_date_units(e, column_types, |u| {
120                UnaryFunc::DatePartTime(func::DatePartTime(u))
121            })
122        }
123        BinaryFunc::DatePartTimestamp(_) if expr1.is_literal() => {
124            precompile_date_units(e, column_types, |u| {
125                UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(u))
126            })
127        }
128        BinaryFunc::DatePartTimestampTz(_) if expr1.is_literal() => {
129            precompile_date_units(e, column_types, |u| {
130                UnaryFunc::DatePartTimestampTz(func::DatePartTimestampTz(u))
131            })
132        }
133        BinaryFunc::DateTruncTimestamp(_) if expr1.is_literal() => {
134            precompile_date_units(e, column_types, |u| {
135                UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(u))
136            })
137        }
138        BinaryFunc::DateTruncTimestampTz(_) if expr1.is_literal() => {
139            precompile_date_units(e, column_types, |u| {
140                UnaryFunc::DateTruncTimestampTz(func::DateTruncTimestampTz(u))
141            })
142        }
143        BinaryFunc::TimezoneTimestampBinary(_) if expr1.is_literal() => {
144            // If the timezone argument is a literal, and we're applying the function on many rows at the same
145            // time we really don't want to parse it again and again, so we parse it once and embed it into the
146            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
147            precompile_timezone(e, column_types, |tz| {
148                UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz))
149            });
150        }
151        BinaryFunc::TimezoneTimestampTzBinary(_) if expr1.is_literal() => {
152            precompile_timezone(e, column_types, |tz| {
153                UnaryFunc::TimezoneTimestampTz(func::TimezoneTimestampTz(tz))
154            });
155        }
156        BinaryFunc::ToCharTimestamp(_) if expr2.is_literal() => {
157            precompile_to_char(e, |format_string, format| {
158                UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
159                    format_string,
160                    format,
161                })
162            });
163        }
164        BinaryFunc::ToCharTimestampTz(_) if expr2.is_literal() => {
165            precompile_to_char(e, |format_string, format| {
166                UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
167                    format_string,
168                    format,
169                })
170            });
171        }
172        BinaryFunc::Eq(_) | BinaryFunc::NotEq(_) if expr2 < expr1 => {
173            // Canonically order elements so that deduplication works better.
174            // Also, the `Literal([c1, c2]) = record_create(e1, e2)` matching
175            // below relies on this canonical ordering.
176            mem::swap(expr1, expr2);
177        }
178        _ => reduce_call_binary_eq_record(e),
179    }
180}
181
182/// Rewrites a call whose literal operand makes it the identity function on
183/// the other operand to that operand, returning whether it fired.
184///
185/// Only patterns that can change neither error nor null behavior are listed:
186/// evaluation at the identity operand is infallible for every listed function
187/// (division by one cannot error; adding an all-zero interval cannot leave
188/// the timestamp domain; trimming the empty character set cannot grow the
189/// string), every listed function propagates nulls, and the surviving operand
190/// has the value and type of the original call. Functions that can error even
191/// at their identity are deliberately absent: e.g. `text_concat` and
192/// `repeat(s, 1)` re-validate the length of an oversized operand, so eliding
193/// them would suppress that error. Floats are also absent: `-0.0 + 0.0` is
194/// `+0.0`, so the additive identities are not exact, and we keep the float
195/// story all-or-nothing for legibility. Numerics are absent pending an answer
196/// on result scale.
197fn reduce_call_binary_identity(e: &mut MirScalarExpr) -> bool {
198    use BinaryFunc::*;
199    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
200        unreachable!()
201    };
202    // For each function: the literal datum under which the call is the
203    // identity on the other operand, and whether the function commutes (so
204    // the identity may appear on either side rather than only the right).
205    let zero_interval = Datum::Interval(Interval::default());
206    let (identity, commutes) = match func {
207        AddInt16(_) => (Datum::Int16(0), true),
208        AddInt32(_) => (Datum::Int32(0), true),
209        AddInt64(_) => (Datum::Int64(0), true),
210        AddUint16(_) => (Datum::UInt16(0), true),
211        AddUint32(_) => (Datum::UInt32(0), true),
212        AddUint64(_) => (Datum::UInt64(0), true),
213        SubInt16(_) => (Datum::Int16(0), false),
214        SubInt32(_) => (Datum::Int32(0), false),
215        SubInt64(_) => (Datum::Int64(0), false),
216        SubUint16(_) => (Datum::UInt16(0), false),
217        SubUint32(_) => (Datum::UInt32(0), false),
218        SubUint64(_) => (Datum::UInt64(0), false),
219        MulInt16(_) => (Datum::Int16(1), true),
220        MulInt32(_) => (Datum::Int32(1), true),
221        MulInt64(_) => (Datum::Int64(1), true),
222        MulUint16(_) => (Datum::UInt16(1), true),
223        MulUint32(_) => (Datum::UInt32(1), true),
224        MulUint64(_) => (Datum::UInt64(1), true),
225        DivInt16(_) => (Datum::Int16(1), false),
226        DivInt32(_) => (Datum::Int32(1), false),
227        DivInt64(_) => (Datum::Int64(1), false),
228        DivUint16(_) => (Datum::UInt16(1), false),
229        DivUint32(_) => (Datum::UInt32(1), false),
230        DivUint64(_) => (Datum::UInt64(1), false),
231        // Temporal: an all-zero interval added to or subtracted from a
232        // timestamp, time, or interval. (Dates are absent: their interval
233        // arithmetic changes the type to timestamp.)
234        AddInterval(_) => (zero_interval, true),
235        SubInterval(_)
236        | AddTimestampInterval(_)
237        | AddTimestampTzInterval(_)
238        | AddTimeInterval(_)
239        | SubTimestampInterval(_)
240        | SubTimestampTzInterval(_)
241        | SubTimeInterval(_) => (zero_interval, false),
242        // Bitwise: zero is the identity for or/xor, all-ones for and, and a
243        // shift distance of zero (always an `int4`) leaves the value alone.
244        BitOrInt16(_) | BitXorInt16(_) => (Datum::Int16(0), true),
245        BitOrInt32(_) | BitXorInt32(_) => (Datum::Int32(0), true),
246        BitOrInt64(_) | BitXorInt64(_) => (Datum::Int64(0), true),
247        BitOrUint16(_) | BitXorUint16(_) => (Datum::UInt16(0), true),
248        BitOrUint32(_) | BitXorUint32(_) => (Datum::UInt32(0), true),
249        BitOrUint64(_) | BitXorUint64(_) => (Datum::UInt64(0), true),
250        BitAndInt16(_) => (Datum::Int16(-1), true),
251        BitAndInt32(_) => (Datum::Int32(-1), true),
252        BitAndInt64(_) => (Datum::Int64(-1), true),
253        BitAndUint16(_) => (Datum::UInt16(u16::MAX), true),
254        BitAndUint32(_) => (Datum::UInt32(u32::MAX), true),
255        BitAndUint64(_) => (Datum::UInt64(u64::MAX), true),
256        BitShiftLeftInt16(_)
257        | BitShiftRightInt16(_)
258        | BitShiftLeftInt32(_)
259        | BitShiftRightInt32(_)
260        | BitShiftLeftInt64(_)
261        | BitShiftRightInt64(_)
262        | BitShiftLeftUint16(_)
263        | BitShiftRightUint16(_)
264        | BitShiftLeftUint32(_)
265        | BitShiftRightUint32(_)
266        | BitShiftLeftUint64(_)
267        | BitShiftRightUint64(_) => (Datum::Int32(0), false),
268        // Trimming the empty set of characters.
269        Trim(_) | TrimLeading(_) | TrimTrailing(_) => (Datum::String(""), false),
270        _ => return false,
271    };
272    if matches!(expr2.as_literal(), Some(Ok(d)) if d == identity) {
273        *e = expr1.take();
274        true
275    } else if commutes && matches!(expr1.as_literal(), Some(Ok(d)) if d == identity) {
276        *e = expr2.take();
277        true
278    } else {
279        false
280    }
281}
282
283/// Decomposes equality with `RecordCreate` on the right-hand side.
284///
285/// Handles two cases:
286/// - `Literal([c1, ...]) = record_create(e1, ...)` → `c1 = e1 AND ...`
287/// - `record_create(a1, ...) = record_create(b1, ...)` → `a1 = b1 AND ...`
288///
289/// `MapFilterProject::literal_constraints` relies on the first transform,
290/// because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
291fn reduce_call_binary_eq_record(e: &mut MirScalarExpr) {
292    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
293        unreachable!()
294    };
295
296    match (&*func, &**expr1, &**expr2) {
297        (
298            BinaryFunc::Eq(_),
299            MirScalarExpr::Literal(
300                Ok(lit_row),
301                ReprColumnType {
302                    scalar_type:
303                        ReprScalarType::Record {
304                            fields: field_types,
305                            ..
306                        },
307                    ..
308                },
309            ),
310            MirScalarExpr::CallVariadic {
311                func: VariadicFunc::RecordCreate(..),
312                exprs: rec_create_args,
313            },
314        ) => {
315            // Literal([c1, c2]) = record_create(e1, e2)
316            //  -->
317            // c1 = e1 AND c2 = e2
318            //
319            // (Records are represented as lists.)
320            //
321            // `MapFilterProject::literal_constraints` relies on this transform,
322            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
323            if let Datum::List(datum_list) = lit_row.unpack_first() {
324                *e = MirScalarExpr::call_variadic(
325                    And,
326                    datum_list
327                        .iter()
328                        .zip_eq(field_types)
329                        .zip_eq(rec_create_args)
330                        .map(|((d, typ), a)| {
331                            MirScalarExpr::literal_ok(d, typ.scalar_type.clone())
332                                .call_binary(a.clone(), func::Eq)
333                        })
334                        .collect(),
335                );
336            }
337        }
338        (
339            BinaryFunc::Eq(_),
340            MirScalarExpr::CallVariadic {
341                func: VariadicFunc::RecordCreate(..),
342                exprs: rec_create_args1,
343            },
344            MirScalarExpr::CallVariadic {
345                func: VariadicFunc::RecordCreate(..),
346                exprs: rec_create_args2,
347            },
348        ) => {
349            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
350            //  -->
351            // a1 = b1 AND a2 = b2 AND ...
352            //
353            // This is similar to the previous reduction, but this one kicks in also
354            // when only some (or none) of the record fields are literals. This
355            // enables the discovery of literal constraints for those fields.
356            //
357            // Note that there is a similar decomposition in
358            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
359            // pipeline than the compilation of IN lists to `record_create`.
360            *e = MirScalarExpr::call_variadic(
361                And,
362                rec_create_args1
363                    .into_iter()
364                    .zip_eq(rec_create_args2)
365                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
366                    .collect(),
367            );
368        }
369        _ => {}
370    }
371}
372
373/// Specializes a binary date/time function call whose units argument is a
374/// literal string. Produces either a unary call with the parsed units baked
375/// in, or a literal `UnknownUnits` error.
376fn precompile_date_units<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
377where
378    F: FnOnce(DateTimeUnits) -> UnaryFunc,
379{
380    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
381        unreachable!()
382    };
383    let units_str = expr1.as_literal_str().unwrap();
384    *e = match units_str.parse::<DateTimeUnits>() {
385        Ok(units) => MirScalarExpr::CallUnary {
386            func: build_unary(units),
387            expr: Box::new(expr2.take()),
388        },
389        Err(_) => MirScalarExpr::literal(
390            Err(EvalError::UnknownUnits(units_str.into())),
391            e.typ(column_types).scalar_type,
392        ),
393    };
394}
395
396/// Specializes a binary timezone-applying function call whose timezone
397/// argument is a literal. Produces either a unary call with the parsed
398/// timezone baked in, or a literal error.
399fn precompile_timezone<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
400where
401    F: FnOnce(Timezone) -> UnaryFunc,
402{
403    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
404        unreachable!()
405    };
406    let tz_str = expr1.as_literal_str().unwrap();
407    *e = match parse_timezone(tz_str, TimezoneSpec::Posix) {
408        Ok(tz) => MirScalarExpr::CallUnary {
409            func: build_unary(tz),
410            expr: Box::new(expr2.take()),
411        },
412        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
413    };
414}
415
416/// Specializes a `to_char_*` binary call whose format-string argument is a
417/// literal, by compiling the format into the unary form.
418fn precompile_to_char<F>(e: &mut MirScalarExpr, build_unary: F)
419where
420    F: FnOnce(String, DateTimeFormat) -> UnaryFunc,
421{
422    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
423        unreachable!()
424    };
425    let format_str = expr2.as_literal_str().unwrap().to_owned();
426    let compiled = DateTimeFormat::compile(&format_str);
427    *e = MirScalarExpr::CallUnary {
428        func: build_unary(format_str, compiled),
429        expr: Box::new(expr1.take()),
430    };
431}
432
433/// Specializes an `IsLikeMatch{CaseSensitive,CaseInsensitive}` binary call
434/// whose pattern argument is a literal, by precompiling the matcher.
435fn precompile_is_like(
436    e: &mut MirScalarExpr,
437    column_types: &[ReprColumnType],
438    case_insensitive: bool,
439) {
440    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
441        unreachable!()
442    };
443    let pattern = expr2.as_literal_str().unwrap();
444    *e = match like_pattern::compile(pattern, case_insensitive) {
445        Ok(matcher) => expr1
446            .take()
447            .call_unary(UnaryFunc::IsLikeMatch(func::IsLikeMatch(matcher))),
448        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
449    };
450}
451
#[cfg(test)]
mod tests {
    use mz_repr::adt::interval::Interval;
    use mz_repr::{Datum, ReprScalarType};

    use crate::MirScalarExpr;
    use crate::scalar::func;

    /// The sole input column that every test expression reads.
    fn input() -> MirScalarExpr {
        MirScalarExpr::column(0)
    }

    /// An `int4` literal holding `value`.
    fn int4(value: i32) -> MirScalarExpr {
        MirScalarExpr::literal_ok(Datum::Int32(value), ReprScalarType::Int32)
    }

    #[mz_ore::test]
    fn identity_operand_folds() {
        let schema = [ReprScalarType::Int32.nullable(true)];

        // An identity on the right always folds to the other operand; one on
        // the left folds only for commutative functions.
        let folding = [
            input().call_binary(int4(0), func::AddInt32),
            int4(0).call_binary(input(), func::AddInt32),
            input().call_binary(int4(0), func::SubInt32),
            input().call_binary(int4(1), func::MulInt32),
            int4(1).call_binary(input(), func::MulInt32),
            input().call_binary(int4(1), func::DivInt32),
            input().call_binary(int4(0), func::BitOrInt32),
            input().call_binary(int4(0), func::BitXorInt32),
            input().call_binary(int4(-1), func::BitAndInt32),
            input().call_binary(int4(0), func::BitShiftLeftInt32),
            input().call_binary(int4(0), func::BitShiftRightInt32),
        ];
        for mut expr in folding {
            expr.reduce(&schema);
            assert_eq!(expr, input(), "expected fold to the column");
        }

        // A non-identity literal, or an identity on the left of a
        // non-commutative function, must leave the call in place.
        let non_folding = [
            input().call_binary(int4(1), func::AddInt32),
            int4(0).call_binary(input(), func::SubInt32),
            int4(1).call_binary(input(), func::DivInt32),
            input().call_binary(int4(0), func::BitAndInt32),
        ];
        for mut expr in non_folding {
            expr.reduce(&schema);
            assert_ne!(expr, input(), "expected no fold to the column");
        }
    }

    #[mz_ore::test]
    fn identity_operand_folds_trim_and_interval() {
        // Trimming the empty character set is the identity.
        let text_schema = [ReprScalarType::String.nullable(true)];
        let trims = [
            crate::BinaryFunc::Trim(func::Trim),
            crate::BinaryFunc::TrimLeading(func::TrimLeading),
            crate::BinaryFunc::TrimTrailing(func::TrimTrailing),
        ];
        for trim in trims {
            let empty_set = MirScalarExpr::literal_ok(Datum::String(""), ReprScalarType::String);
            let mut expr = input().call_binary(empty_set, trim);
            expr.reduce(&text_schema);
            assert_eq!(expr, input());
        }

        // Adding an all-zero interval to a timestamp is the identity.
        let timestamp_schema = [ReprScalarType::Timestamp {}.nullable(true)];
        let zero_interval = MirScalarExpr::literal_ok(
            Datum::Interval(Interval::default()),
            ReprScalarType::Interval,
        );
        let mut expr = input().call_binary(zero_interval, func::AddTimestampInterval);
        expr.reduce(&timestamp_schema);
        assert_eq!(expr, input());
    }
}