Skip to main content

mz_expr/scalar/reduce/
binary.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
//! Post-order rewrites for `CallBinary` nodes.
11
12use std::mem;
13
14use itertools::Itertools;
15use mz_pgtz::timezone::{Timezone, TimezoneSpec};
16use mz_repr::adt::datetime::DateTimeUnits;
17use mz_repr::adt::regex::Regex;
18use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena};
19
20use crate::EvalError;
21use crate::MirScalarExpr;
22use crate::scalar::func::format::DateTimeFormat;
23use crate::scalar::func::variadic::And;
24use crate::scalar::func::{self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone};
25use crate::scalar::like_pattern;
26
27pub(super) fn reduce_call_binary(
28    e: &mut MirScalarExpr,
29    column_types: &[ReprColumnType],
30    temp_storage: &RowArena,
31) {
32    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
33        unreachable!()
34    };
35
36    // Fold/propagate literal-shaped operands first; precompiles below assume
37    // these have already fired.
38    if expr1.is_literal() && expr2.is_literal() {
39        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
40        return;
41    }
42    if (expr1.is_literal_null() || expr2.is_literal_null()) && func.propagates_nulls() {
43        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
44        return;
45    }
46    if let Some(err) = expr1.as_literal_err() {
47        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
48        return;
49    }
50    if let Some(err) = expr2.as_literal_err() {
51        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
52        return;
53    }
54
55    // Per-function dispatch. Each precompile fires only if its literal-shaped
56    // argument is present; otherwise the call falls through unchanged.
57    match func {
58        BinaryFunc::IsLikeMatchCaseInsensitive(_) if expr2.is_literal() => {
59            // We can at least precompile the regex.
60            precompile_is_like(e, column_types, true);
61        }
62        BinaryFunc::IsLikeMatchCaseSensitive(_) if expr2.is_literal() => {
63            // We can at least precompile the regex.
64            precompile_is_like(e, column_types, false);
65        }
66        BinaryFunc::IsRegexpMatchCaseSensitive(_) | BinaryFunc::IsRegexpMatchCaseInsensitive(_) => {
67            let case_insensitive = matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
68            if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
69                *e = match Regex::new(row.unpack_first().unwrap_str(), case_insensitive) {
70                    Ok(regex) => expr1
71                        .take()
72                        .call_unary(UnaryFunc::IsRegexpMatch(func::IsRegexpMatch(regex))),
73                    Err(err) => {
74                        MirScalarExpr::literal(Err(err.into()), e.typ(column_types).scalar_type)
75                    }
76                };
77            }
78        }
79        BinaryFunc::ExtractInterval(_) if expr1.is_literal() => {
80            precompile_date_units(e, column_types, |u| {
81                UnaryFunc::ExtractInterval(func::ExtractInterval(u))
82            })
83        }
84        BinaryFunc::ExtractTime(_) if expr1.is_literal() => {
85            precompile_date_units(e, column_types, |u| {
86                UnaryFunc::ExtractTime(func::ExtractTime(u))
87            })
88        }
89        BinaryFunc::ExtractTimestamp(_) if expr1.is_literal() => {
90            precompile_date_units(e, column_types, |u| {
91                UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(u))
92            })
93        }
94        BinaryFunc::ExtractTimestampTz(_) if expr1.is_literal() => {
95            precompile_date_units(e, column_types, |u| {
96                UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(u))
97            })
98        }
99        BinaryFunc::ExtractDate(_) if expr1.is_literal() => {
100            precompile_date_units(e, column_types, |u| {
101                UnaryFunc::ExtractDate(func::ExtractDate(u))
102            })
103        }
104        BinaryFunc::DatePartInterval(_) if expr1.is_literal() => {
105            precompile_date_units(e, column_types, |u| {
106                UnaryFunc::DatePartInterval(func::DatePartInterval(u))
107            })
108        }
109        BinaryFunc::DatePartTime(_) if expr1.is_literal() => {
110            precompile_date_units(e, column_types, |u| {
111                UnaryFunc::DatePartTime(func::DatePartTime(u))
112            })
113        }
114        BinaryFunc::DatePartTimestamp(_) if expr1.is_literal() => {
115            precompile_date_units(e, column_types, |u| {
116                UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(u))
117            })
118        }
119        BinaryFunc::DatePartTimestampTz(_) if expr1.is_literal() => {
120            precompile_date_units(e, column_types, |u| {
121                UnaryFunc::DatePartTimestampTz(func::DatePartTimestampTz(u))
122            })
123        }
124        BinaryFunc::DateTruncTimestamp(_) if expr1.is_literal() => {
125            precompile_date_units(e, column_types, |u| {
126                UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(u))
127            })
128        }
129        BinaryFunc::DateTruncTimestampTz(_) if expr1.is_literal() => {
130            precompile_date_units(e, column_types, |u| {
131                UnaryFunc::DateTruncTimestampTz(func::DateTruncTimestampTz(u))
132            })
133        }
134        BinaryFunc::TimezoneTimestampBinary(_) if expr1.is_literal() => {
135            // If the timezone argument is a literal, and we're applying the function on many rows at the same
136            // time we really don't want to parse it again and again, so we parse it once and embed it into the
137            // UnaryFunc enum. The memory footprint of Timezone is small (8 bytes).
138            precompile_timezone(e, column_types, |tz| {
139                UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz))
140            });
141        }
142        BinaryFunc::TimezoneTimestampTzBinary(_) if expr1.is_literal() => {
143            precompile_timezone(e, column_types, |tz| {
144                UnaryFunc::TimezoneTimestampTz(func::TimezoneTimestampTz(tz))
145            });
146        }
147        BinaryFunc::ToCharTimestamp(_) if expr2.is_literal() => {
148            precompile_to_char(e, |format_string, format| {
149                UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
150                    format_string,
151                    format,
152                })
153            });
154        }
155        BinaryFunc::ToCharTimestampTz(_) if expr2.is_literal() => {
156            precompile_to_char(e, |format_string, format| {
157                UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
158                    format_string,
159                    format,
160                })
161            });
162        }
163        BinaryFunc::Eq(_) | BinaryFunc::NotEq(_) if expr2 < expr1 => {
164            // Canonically order elements so that deduplication works better.
165            // Also, the `Literal([c1, c2]) = record_create(e1, e2)` matching
166            // below relies on this canonical ordering.
167            mem::swap(expr1, expr2);
168        }
169        _ => reduce_call_binary_eq_record(e),
170    }
171}
172
173/// Decomposes equality with `RecordCreate` on the right-hand side.
174///
175/// Handles two cases:
176/// - `Literal([c1, ...]) = record_create(e1, ...)` → `c1 = e1 AND ...`
177/// - `record_create(a1, ...) = record_create(b1, ...)` → `a1 = b1 AND ...`
178///
179/// `MapFilterProject::literal_constraints` relies on the first transform,
180/// because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
181fn reduce_call_binary_eq_record(e: &mut MirScalarExpr) {
182    let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
183        unreachable!()
184    };
185
186    match (&*func, &**expr1, &**expr2) {
187        (
188            BinaryFunc::Eq(_),
189            MirScalarExpr::Literal(
190                Ok(lit_row),
191                ReprColumnType {
192                    scalar_type:
193                        ReprScalarType::Record {
194                            fields: field_types,
195                            ..
196                        },
197                    ..
198                },
199            ),
200            MirScalarExpr::CallVariadic {
201                func: VariadicFunc::RecordCreate(..),
202                exprs: rec_create_args,
203            },
204        ) => {
205            // Literal([c1, c2]) = record_create(e1, e2)
206            //  -->
207            // c1 = e1 AND c2 = e2
208            //
209            // (Records are represented as lists.)
210            //
211            // `MapFilterProject::literal_constraints` relies on this transform,
212            // because `(e1,e2) IN ((1,2))` is desugared using `record_create`.
213            if let Datum::List(datum_list) = lit_row.unpack_first() {
214                *e = MirScalarExpr::call_variadic(
215                    And,
216                    datum_list
217                        .iter()
218                        .zip_eq(field_types)
219                        .zip_eq(rec_create_args)
220                        .map(|((d, typ), a)| {
221                            MirScalarExpr::literal_ok(d, typ.scalar_type.clone())
222                                .call_binary(a.clone(), func::Eq)
223                        })
224                        .collect(),
225                );
226            }
227        }
228        (
229            BinaryFunc::Eq(_),
230            MirScalarExpr::CallVariadic {
231                func: VariadicFunc::RecordCreate(..),
232                exprs: rec_create_args1,
233            },
234            MirScalarExpr::CallVariadic {
235                func: VariadicFunc::RecordCreate(..),
236                exprs: rec_create_args2,
237            },
238        ) => {
239            // record_create(a1, a2, ...) = record_create(b1, b2, ...)
240            //  -->
241            // a1 = b1 AND a2 = b2 AND ...
242            //
243            // This is similar to the previous reduction, but this one kicks in also
244            // when only some (or none) of the record fields are literals. This
245            // enables the discovery of literal constraints for those fields.
246            //
247            // Note that there is a similar decomposition in
248            // `mz_sql::plan::transform_ast::Desugarer`, but that is earlier in the
249            // pipeline than the compilation of IN lists to `record_create`.
250            *e = MirScalarExpr::call_variadic(
251                And,
252                rec_create_args1
253                    .into_iter()
254                    .zip_eq(rec_create_args2)
255                    .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
256                    .collect(),
257            );
258        }
259        _ => {}
260    }
261}
262
263/// Specializes a binary date/time function call whose units argument is a
264/// literal string. Produces either a unary call with the parsed units baked
265/// in, or a literal `UnknownUnits` error.
266fn precompile_date_units<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
267where
268    F: FnOnce(DateTimeUnits) -> UnaryFunc,
269{
270    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
271        unreachable!()
272    };
273    let units_str = expr1.as_literal_str().unwrap();
274    *e = match units_str.parse::<DateTimeUnits>() {
275        Ok(units) => MirScalarExpr::CallUnary {
276            func: build_unary(units),
277            expr: Box::new(expr2.take()),
278        },
279        Err(_) => MirScalarExpr::literal(
280            Err(EvalError::UnknownUnits(units_str.into())),
281            e.typ(column_types).scalar_type,
282        ),
283    };
284}
285
286/// Specializes a binary timezone-applying function call whose timezone
287/// argument is a literal. Produces either a unary call with the parsed
288/// timezone baked in, or a literal error.
289fn precompile_timezone<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
290where
291    F: FnOnce(Timezone) -> UnaryFunc,
292{
293    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
294        unreachable!()
295    };
296    let tz_str = expr1.as_literal_str().unwrap();
297    *e = match parse_timezone(tz_str, TimezoneSpec::Posix) {
298        Ok(tz) => MirScalarExpr::CallUnary {
299            func: build_unary(tz),
300            expr: Box::new(expr2.take()),
301        },
302        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
303    };
304}
305
306/// Specializes a `to_char_*` binary call whose format-string argument is a
307/// literal, by compiling the format into the unary form.
308fn precompile_to_char<F>(e: &mut MirScalarExpr, build_unary: F)
309where
310    F: FnOnce(String, DateTimeFormat) -> UnaryFunc,
311{
312    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
313        unreachable!()
314    };
315    let format_str = expr2.as_literal_str().unwrap().to_owned();
316    let compiled = DateTimeFormat::compile(&format_str);
317    *e = MirScalarExpr::CallUnary {
318        func: build_unary(format_str, compiled),
319        expr: Box::new(expr1.take()),
320    };
321}
322
323/// Specializes an `IsLikeMatch{CaseSensitive,CaseInsensitive}` binary call
324/// whose pattern argument is a literal, by precompiling the matcher.
325fn precompile_is_like(
326    e: &mut MirScalarExpr,
327    column_types: &[ReprColumnType],
328    case_insensitive: bool,
329) {
330    let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
331        unreachable!()
332    };
333    let pattern = expr2.as_literal_str().unwrap();
334    *e = match like_pattern::compile(pattern, case_insensitive) {
335        Ok(matcher) => expr1
336            .take()
337            .call_unary(UnaryFunc::IsLikeMatch(func::IsLikeMatch(matcher))),
338        Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
339    };
340}