Skip to main content

mz_expr/scalar/reduce/
binary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallBinary` nodes.
11
12use std::mem;
13
14use itertools::Itertools;
15use mz_pgtz::timezone::{Timezone, TimezoneSpec};
16use mz_repr::adt::datetime::DateTimeUnits;
17use mz_repr::adt::regex::Regex;
18use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena};
19
20use crate::scalar::func::format::DateTimeFormat;
21use crate::scalar::func::variadic::And;
22use crate::scalar::func::{self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone};
23use crate::scalar::like_pattern;
24use crate::{Eval, EvalError, MirScalarExpr};
25
26pub(super) fn reduce_call_binary(
27    e: &mut MirScalarExpr,
28    column_types: &[ReprColumnType],
29    temp_storage: &RowArena,
30) {
31    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
32        unreachable!()
33    };
34
35    // Fold/propagate literal-shaped operands first; precompiles below assume
36    // these have already fired.
37    if expr1.is_literal() && expr2.is_literal() {
38        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
39        return;
40    }
41    if (expr1.is_literal_null() || expr2.is_literal_null()) && func.propagates_nulls() {
42        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
43        return;
44    }
45    if let Some(err) = expr1.as_literal_err() {
46        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
47        return;
48    }
49    if let Some(err) = expr2.as_literal_err() {
50        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
51        return;
52    }
53
54    // Per-function dispatch. Each precompile fires only if its literal-shaped
55    // argument is present; otherwise the call falls through unchanged.
56    match func {
57        BinaryFunc::IsLikeMatchCaseInsensitive(_) if expr2.is_literal() => {
58            // We can at least precompile the regex.
59            precompile_is_like(e, column_types, true);
60        }
61        BinaryFunc::IsLikeMatchCaseSensitive(_) if expr2.is_literal() => {
62            // We can at least precompile the regex.
63            precompile_is_like(e, column_types, false);
64        }
65        BinaryFunc::IsRegexpMatchCaseSensitive(_) | BinaryFunc::IsRegexpMatchCaseInsensitive(_) => {
66            let case_insensitive = matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
67            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
68                *e = match Regex::new(row.unpack_first().unwrap_str(), case_insensitive) {
69                    Ok(regex) => expr1
70                        .take()
71                        .call_unary(UnaryFunc::IsRegexpMatch(func::IsRegexpMatch(regex))),
72                    Err(err) => {
73                        MirScalarExpr::literal(Err(err.into()), e.typ(column_types).scalar_type)
74                    }
75                };
76            }
77        }
78        BinaryFunc::ExtractInterval(_) if expr1.is_literal() => {
79            precompile_date_units(e, column_types, |u| {
80                UnaryFunc::ExtractInterval(func::ExtractInterval(u))
81            })
82        }
83        BinaryFunc::ExtractTime(_) if expr1.is_literal() => {
84            precompile_date_units(e, column_types, |u| {
85                UnaryFunc::ExtractTime(func::ExtractTime(u))
86            })
87        }
88        BinaryFunc::ExtractTimestamp(_) if expr1.is_literal() => {
89            precompile_date_units(e, column_types, |u| {
90                UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(u))
91            })
92        }
93        BinaryFunc::ExtractTimestampTz(_) if expr1.is_literal() => {
94            precompile_date_units(e, column_types, |u| {
95                UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(u))
96            })
97        }
98        BinaryFunc::ExtractDate(_) if expr1.is_literal() => {
99            precompile_date_units(e, column_types, |u| {
100                UnaryFunc::ExtractDate(func::ExtractDate(u))
101            })
102        }
103        BinaryFunc::DatePartInterval(_) if expr1.is_literal() => {
104            precompile_date_units(e, column_types, |u| {
105                UnaryFunc::DatePartInterval(func::DatePartInterval(u))
106            })
107        }
108        BinaryFunc::DatePartTime(_) if expr1.is_literal() => {
109            precompile_date_units(e, column_types, |u| {
110                UnaryFunc::DatePartTime(func::DatePartTime(u))
111            })
112        }
113        BinaryFunc::DatePartTimestamp(_) if expr1.is_literal() => {
114            precompile_date_units(e, column_types, |u| {
115                UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(u))
116            })
117        }
118        BinaryFunc::DatePartTimestampTz(_) if expr1.is_literal() => {
119            precompile_date_units(e, column_types, |u| {
120                UnaryFunc::DatePartTimestampTz(func::DatePartTimestampTz(u))
121            })
122        }
123        BinaryFunc::DateTruncTimestamp(_) if expr1.is_literal() => {
124            precompile_date_units(e, column_types, |u| {
125                UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(u))
126            })
127        }
128        BinaryFunc::DateTruncTimestampTz(_) if expr1.is_literal() => {
129            precompile_date_units(e, column_types, |u| {
130                UnaryFunc::DateTruncTimestampTz(func::DateTruncTimestampTz(u))
131            })
132        }
133        BinaryFunc::TimezoneTimestampBinary(_) if expr1.is_literal() => {
134            // If the timezone argument is a literal, and we're applying the function on many rows at the same
135            // time we really don't want to parse it again and again, so we parse it once and embed it into the
136            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
137            precompile_timezone(e, column_types, |tz| {
138                UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz))
139            });
140        }
141        BinaryFunc::TimezoneTimestampTzBinary(_) if expr1.is_literal() => {
142            precompile_timezone(e, column_types, |tz| {
143                UnaryFunc::TimezoneTimestampTz(func::TimezoneTimestampTz(tz))
144            });
145        }
146        BinaryFunc::ToCharTimestamp(_) if expr2.is_literal() => {
147            precompile_to_char(e, |format_string, format| {
148                UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
149                    format_string,
150                    format,
151                })
152            });
153        }
154        BinaryFunc::ToCharTimestampTz(_) if expr2.is_literal() => {
155            precompile_to_char(e, |format_string, format| {
156                UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
157                    format_string,
158                    format,
159                })
160            });
161        }
162        BinaryFunc::Eq(_) | BinaryFunc::NotEq(_) if expr2 < expr1 => {
163            // Canonically order elements so that deduplication works better.
164            // Also, the `Literal([c1, c2]) = record_create(e1, e2)` matching
165            // below relies on this canonical ordering.
166            mem::swap(expr1, expr2);
167        }
168        _ => reduce_call_binary_eq_record(e),
169    }
170}
171
172/// Decomposes equality with `RecordCreate` on the right-hand side.
173///
174/// Handles two cases:
175/// - `Literal([c1, ...]) = record_create(e1, ...)` → `c1 = e1 AND ...`
176/// - `record_create(a1, ...) = record_create(b1, ...)` → `a1 = b1 AND ...`
177///
178/// `MapFilterProject::literal_constraints` relies on the first transform,
179/// because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
180fn reduce_call_binary_eq_record(e: &mut MirScalarExpr) {
181    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
182        unreachable!()
183    };
184
185    match (&*func, &**expr1, &**expr2) {
186        (
187            BinaryFunc::Eq(_),
188            MirScalarExpr::Literal(
189                Ok(lit_row),
190                ReprColumnType {
191                    scalar_type:
192                        ReprScalarType::Record {
193                            fields: field_types,
194                            ..
195                        },
196                    ..
197                },
198            ),
199            MirScalarExpr::CallVariadic {
200                func: VariadicFunc::RecordCreate(..),
201                exprs: rec_create_args,
202            },
203        ) => {
204            // Literal([c1, c2]) = record_create(e1, e2)
205            //  -->
206            // c1 = e1 AND c2 = e2
207            //
208            // (Records are represented as lists.)
209            //
210            // `MapFilterProject::literal_constraints` relies on this transform,
211            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
212            if let Datum::List(datum_list) = lit_row.unpack_first() {
213                *e = MirScalarExpr::call_variadic(
214                    And,
215                    datum_list
216                        .iter()
217                        .zip_eq(field_types)
218                        .zip_eq(rec_create_args)
219                        .map(|((d, typ), a)| {
220                            MirScalarExpr::literal_ok(d, typ.scalar_type.clone())
221                                .call_binary(a.clone(), func::Eq)
222                        })
223                        .collect(),
224                );
225            }
226        }
227        (
228            BinaryFunc::Eq(_),
229            MirScalarExpr::CallVariadic {
230                func: VariadicFunc::RecordCreate(..),
231                exprs: rec_create_args1,
232            },
233            MirScalarExpr::CallVariadic {
234                func: VariadicFunc::RecordCreate(..),
235                exprs: rec_create_args2,
236            },
237        ) => {
238            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
239            //  -->
240            // a1 = b1 AND a2 = b2 AND ...
241            //
242            // This is similar to the previous reduction, but this one kicks in also
243            // when only some (or none) of the record fields are literals. This
244            // enables the discovery of literal constraints for those fields.
245            //
246            // Note that there is a similar decomposition in
247            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
248            // pipeline than the compilation of IN lists to `record_create`.
249            *e = MirScalarExpr::call_variadic(
250                And,
251                rec_create_args1
252                    .into_iter()
253                    .zip_eq(rec_create_args2)
254                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
255                    .collect(),
256            );
257        }
258        _ => {}
259    }
260}
261
262/// Specializes a binary date/time function call whose units argument is a
263/// literal string. Produces either a unary call with the parsed units baked
264/// in, or a literal `UnknownUnits` error.
265fn precompile_date_units<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
266where
267    F: FnOnce(DateTimeUnits) -> UnaryFunc,
268{
269    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
270        unreachable!()
271    };
272    let units_str = expr1.as_literal_str().unwrap();
273    *e = match units_str.parse::<DateTimeUnits>() {
274        Ok(units) => MirScalarExpr::CallUnary {
275            func: build_unary(units),
276            expr: Box::new(expr2.take()),
277        },
278        Err(_) => MirScalarExpr::literal(
279            Err(EvalError::UnknownUnits(units_str.into())),
280            e.typ(column_types).scalar_type,
281        ),
282    };
283}
284
285/// Specializes a binary timezone-applying function call whose timezone
286/// argument is a literal. Produces either a unary call with the parsed
287/// timezone baked in, or a literal error.
288fn precompile_timezone<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
289where
290    F: FnOnce(Timezone) -> UnaryFunc,
291{
292    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
293        unreachable!()
294    };
295    let tz_str = expr1.as_literal_str().unwrap();
296    *e = match parse_timezone(tz_str, TimezoneSpec::Posix) {
297        Ok(tz) => MirScalarExpr::CallUnary {
298            func: build_unary(tz),
299            expr: Box::new(expr2.take()),
300        },
301        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
302    };
303}
304
305/// Specializes a `to_char_*` binary call whose format-string argument is a
306/// literal, by compiling the format into the unary form.
307fn precompile_to_char<F>(e: &mut MirScalarExpr, build_unary: F)
308where
309    F: FnOnce(String, DateTimeFormat) -> UnaryFunc,
310{
311    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
312        unreachable!()
313    };
314    let format_str = expr2.as_literal_str().unwrap().to_owned();
315    let compiled = DateTimeFormat::compile(&format_str);
316    *e = MirScalarExpr::CallUnary {
317        func: build_unary(format_str, compiled),
318        expr: Box::new(expr1.take()),
319    };
320}
321
322/// Specializes an `IsLikeMatch{CaseSensitive,CaseInsensitive}` binary call
323/// whose pattern argument is a literal, by precompiling the matcher.
324fn precompile_is_like(
325    e: &mut MirScalarExpr,
326    column_types: &[ReprColumnType],
327    case_insensitive: bool,
328) {
329    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
330        unreachable!()
331    };
332    let pattern = expr2.as_literal_str().unwrap();
333    *e = match like_pattern::compile(pattern, case_insensitive) {
334        Ok(matcher) => expr1
335            .take()
336            .call_unary(UnaryFunc::IsLikeMatch(func::IsLikeMatch(matcher))),
337        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
338    };
339}