1use std::mem;
13
14use itertools::Itertools;
15use mz_pgtz::timezone::{Timezone, TimezoneSpec};
16use mz_repr::adt::datetime::DateTimeUnits;
17use mz_repr::adt::interval::Interval;
18use mz_repr::adt::regex::Regex;
19use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena};
20
21use crate::scalar::func::format::DateTimeFormat;
22use crate::scalar::func::variadic::And;
23use crate::scalar::func::{self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone};
24use crate::scalar::like_pattern;
25use crate::{Eval, EvalError, MirScalarExpr};
26
27pub(super) fn reduce_call_binary(
28 e: &mut MirScalarExpr,
29 column_types: &[ReprColumnType],
30 temp_storage: &RowArena,
31) {
32 let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
33 unreachable!()
34 };
35
36 if expr1.is_literal() && expr2.is_literal() {
39 *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
40 return;
41 }
42 if (expr1.is_literal_null() || expr2.is_literal_null()) && func.propagates_nulls() {
43 *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
44 return;
45 }
46 if let Some(err) = expr1.as_literal_err() {
47 *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
48 return;
49 }
50 if let Some(err) = expr2.as_literal_err() {
51 *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
52 return;
53 }
54
55 if reduce_call_binary_identity(e) {
58 return;
59 }
60 let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
61 unreachable!()
62 };
63
64 match func {
67 BinaryFunc::IsLikeMatchCaseInsensitive(_) if expr2.is_literal() => {
68 precompile_is_like(e, column_types, true);
70 }
71 BinaryFunc::IsLikeMatchCaseSensitive(_) if expr2.is_literal() => {
72 precompile_is_like(e, column_types, false);
74 }
75 BinaryFunc::IsRegexpMatchCaseSensitive(_) | BinaryFunc::IsRegexpMatchCaseInsensitive(_) => {
76 let case_insensitive = matches!(func, BinaryFunc::IsRegexpMatchCaseInsensitive(_));
77 if let MirScalarExpr::Literal(Ok(row), _) = &**expr2 {
78 *e = match Regex::new(row.unpack_first().unwrap_str(), case_insensitive) {
79 Ok(regex) => expr1
80 .take()
81 .call_unary(UnaryFunc::IsRegexpMatch(func::IsRegexpMatch(regex))),
82 Err(err) => {
83 MirScalarExpr::literal(Err(err.into()), e.typ(column_types).scalar_type)
84 }
85 };
86 }
87 }
88 BinaryFunc::ExtractInterval(_) if expr1.is_literal() => {
89 precompile_date_units(e, column_types, |u| {
90 UnaryFunc::ExtractInterval(func::ExtractInterval(u))
91 })
92 }
93 BinaryFunc::ExtractTime(_) if expr1.is_literal() => {
94 precompile_date_units(e, column_types, |u| {
95 UnaryFunc::ExtractTime(func::ExtractTime(u))
96 })
97 }
98 BinaryFunc::ExtractTimestamp(_) if expr1.is_literal() => {
99 precompile_date_units(e, column_types, |u| {
100 UnaryFunc::ExtractTimestamp(func::ExtractTimestamp(u))
101 })
102 }
103 BinaryFunc::ExtractTimestampTz(_) if expr1.is_literal() => {
104 precompile_date_units(e, column_types, |u| {
105 UnaryFunc::ExtractTimestampTz(func::ExtractTimestampTz(u))
106 })
107 }
108 BinaryFunc::ExtractDate(_) if expr1.is_literal() => {
109 precompile_date_units(e, column_types, |u| {
110 UnaryFunc::ExtractDate(func::ExtractDate(u))
111 })
112 }
113 BinaryFunc::DatePartInterval(_) if expr1.is_literal() => {
114 precompile_date_units(e, column_types, |u| {
115 UnaryFunc::DatePartInterval(func::DatePartInterval(u))
116 })
117 }
118 BinaryFunc::DatePartTime(_) if expr1.is_literal() => {
119 precompile_date_units(e, column_types, |u| {
120 UnaryFunc::DatePartTime(func::DatePartTime(u))
121 })
122 }
123 BinaryFunc::DatePartTimestamp(_) if expr1.is_literal() => {
124 precompile_date_units(e, column_types, |u| {
125 UnaryFunc::DatePartTimestamp(func::DatePartTimestamp(u))
126 })
127 }
128 BinaryFunc::DatePartTimestampTz(_) if expr1.is_literal() => {
129 precompile_date_units(e, column_types, |u| {
130 UnaryFunc::DatePartTimestampTz(func::DatePartTimestampTz(u))
131 })
132 }
133 BinaryFunc::DateTruncTimestamp(_) if expr1.is_literal() => {
134 precompile_date_units(e, column_types, |u| {
135 UnaryFunc::DateTruncTimestamp(func::DateTruncTimestamp(u))
136 })
137 }
138 BinaryFunc::DateTruncTimestampTz(_) if expr1.is_literal() => {
139 precompile_date_units(e, column_types, |u| {
140 UnaryFunc::DateTruncTimestampTz(func::DateTruncTimestampTz(u))
141 })
142 }
143 BinaryFunc::TimezoneTimestampBinary(_) if expr1.is_literal() => {
144 precompile_timezone(e, column_types, |tz| {
148 UnaryFunc::TimezoneTimestamp(func::TimezoneTimestamp(tz))
149 });
150 }
151 BinaryFunc::TimezoneTimestampTzBinary(_) if expr1.is_literal() => {
152 precompile_timezone(e, column_types, |tz| {
153 UnaryFunc::TimezoneTimestampTz(func::TimezoneTimestampTz(tz))
154 });
155 }
156 BinaryFunc::ToCharTimestamp(_) if expr2.is_literal() => {
157 precompile_to_char(e, |format_string, format| {
158 UnaryFunc::ToCharTimestamp(func::ToCharTimestamp {
159 format_string,
160 format,
161 })
162 });
163 }
164 BinaryFunc::ToCharTimestampTz(_) if expr2.is_literal() => {
165 precompile_to_char(e, |format_string, format| {
166 UnaryFunc::ToCharTimestampTz(func::ToCharTimestampTz {
167 format_string,
168 format,
169 })
170 });
171 }
172 BinaryFunc::Eq(_) | BinaryFunc::NotEq(_) if expr2 < expr1 => {
173 mem::swap(expr1, expr2);
177 }
178 _ => reduce_call_binary_eq_record(e),
179 }
180}
181
182fn reduce_call_binary_identity(e: &mut MirScalarExpr) -> bool {
198 use BinaryFunc::*;
199 let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
200 unreachable!()
201 };
202 let zero_interval = Datum::Interval(Interval::default());
206 let (identity, commutes) = match func {
207 AddInt16(_) => (Datum::Int16(0), true),
208 AddInt32(_) => (Datum::Int32(0), true),
209 AddInt64(_) => (Datum::Int64(0), true),
210 AddUint16(_) => (Datum::UInt16(0), true),
211 AddUint32(_) => (Datum::UInt32(0), true),
212 AddUint64(_) => (Datum::UInt64(0), true),
213 SubInt16(_) => (Datum::Int16(0), false),
214 SubInt32(_) => (Datum::Int32(0), false),
215 SubInt64(_) => (Datum::Int64(0), false),
216 SubUint16(_) => (Datum::UInt16(0), false),
217 SubUint32(_) => (Datum::UInt32(0), false),
218 SubUint64(_) => (Datum::UInt64(0), false),
219 MulInt16(_) => (Datum::Int16(1), true),
220 MulInt32(_) => (Datum::Int32(1), true),
221 MulInt64(_) => (Datum::Int64(1), true),
222 MulUint16(_) => (Datum::UInt16(1), true),
223 MulUint32(_) => (Datum::UInt32(1), true),
224 MulUint64(_) => (Datum::UInt64(1), true),
225 DivInt16(_) => (Datum::Int16(1), false),
226 DivInt32(_) => (Datum::Int32(1), false),
227 DivInt64(_) => (Datum::Int64(1), false),
228 DivUint16(_) => (Datum::UInt16(1), false),
229 DivUint32(_) => (Datum::UInt32(1), false),
230 DivUint64(_) => (Datum::UInt64(1), false),
231 AddInterval(_) => (zero_interval, true),
235 SubInterval(_)
236 | AddTimestampInterval(_)
237 | AddTimestampTzInterval(_)
238 | AddTimeInterval(_)
239 | SubTimestampInterval(_)
240 | SubTimestampTzInterval(_)
241 | SubTimeInterval(_) => (zero_interval, false),
242 BitOrInt16(_) | BitXorInt16(_) => (Datum::Int16(0), true),
245 BitOrInt32(_) | BitXorInt32(_) => (Datum::Int32(0), true),
246 BitOrInt64(_) | BitXorInt64(_) => (Datum::Int64(0), true),
247 BitOrUint16(_) | BitXorUint16(_) => (Datum::UInt16(0), true),
248 BitOrUint32(_) | BitXorUint32(_) => (Datum::UInt32(0), true),
249 BitOrUint64(_) | BitXorUint64(_) => (Datum::UInt64(0), true),
250 BitAndInt16(_) => (Datum::Int16(-1), true),
251 BitAndInt32(_) => (Datum::Int32(-1), true),
252 BitAndInt64(_) => (Datum::Int64(-1), true),
253 BitAndUint16(_) => (Datum::UInt16(u16::MAX), true),
254 BitAndUint32(_) => (Datum::UInt32(u32::MAX), true),
255 BitAndUint64(_) => (Datum::UInt64(u64::MAX), true),
256 BitShiftLeftInt16(_)
257 | BitShiftRightInt16(_)
258 | BitShiftLeftInt32(_)
259 | BitShiftRightInt32(_)
260 | BitShiftLeftInt64(_)
261 | BitShiftRightInt64(_)
262 | BitShiftLeftUint16(_)
263 | BitShiftRightUint16(_)
264 | BitShiftLeftUint32(_)
265 | BitShiftRightUint32(_)
266 | BitShiftLeftUint64(_)
267 | BitShiftRightUint64(_) => (Datum::Int32(0), false),
268 Trim(_) | TrimLeading(_) | TrimTrailing(_) => (Datum::String(""), false),
270 _ => return false,
271 };
272 if matches!(expr2.as_literal(), Some(Ok(d)) if d == identity) {
273 *e = expr1.take();
274 true
275 } else if commutes && matches!(expr1.as_literal(), Some(Ok(d)) if d == identity) {
276 *e = expr2.take();
277 true
278 } else {
279 false
280 }
281}
282
283fn reduce_call_binary_eq_record(e: &mut MirScalarExpr) {
292 let MirScalarExpr::CallBinary { func, expr1, expr2 } = e else {
293 unreachable!()
294 };
295
296 match (&*func, &**expr1, &**expr2) {
297 (
298 BinaryFunc::Eq(_),
299 MirScalarExpr::Literal(
300 Ok(lit_row),
301 ReprColumnType {
302 scalar_type:
303 ReprScalarType::Record {
304 fields: field_types,
305 ..
306 },
307 ..
308 },
309 ),
310 MirScalarExpr::CallVariadic {
311 func: VariadicFunc::RecordCreate(..),
312 exprs: rec_create_args,
313 },
314 ) => {
315 if let Datum::List(datum_list) = lit_row.unpack_first() {
324 *e = MirScalarExpr::call_variadic(
325 And,
326 datum_list
327 .iter()
328 .zip_eq(field_types)
329 .zip_eq(rec_create_args)
330 .map(|((d, typ), a)| {
331 MirScalarExpr::literal_ok(d, typ.scalar_type.clone())
332 .call_binary(a.clone(), func::Eq)
333 })
334 .collect(),
335 );
336 }
337 }
338 (
339 BinaryFunc::Eq(_),
340 MirScalarExpr::CallVariadic {
341 func: VariadicFunc::RecordCreate(..),
342 exprs: rec_create_args1,
343 },
344 MirScalarExpr::CallVariadic {
345 func: VariadicFunc::RecordCreate(..),
346 exprs: rec_create_args2,
347 },
348 ) => {
349 *e = MirScalarExpr::call_variadic(
361 And,
362 rec_create_args1
363 .into_iter()
364 .zip_eq(rec_create_args2)
365 .map(|(a, b)| a.clone().call_binary(b.clone(), func::Eq))
366 .collect(),
367 );
368 }
369 _ => {}
370 }
371}
372
373fn precompile_date_units<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
377where
378 F: FnOnce(DateTimeUnits) -> UnaryFunc,
379{
380 let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
381 unreachable!()
382 };
383 let units_str = expr1.as_literal_str().unwrap();
384 *e = match units_str.parse::<DateTimeUnits>() {
385 Ok(units) => MirScalarExpr::CallUnary {
386 func: build_unary(units),
387 expr: Box::new(expr2.take()),
388 },
389 Err(_) => MirScalarExpr::literal(
390 Err(EvalError::UnknownUnits(units_str.into())),
391 e.typ(column_types).scalar_type,
392 ),
393 };
394}
395
396fn precompile_timezone<F>(e: &mut MirScalarExpr, column_types: &[ReprColumnType], build_unary: F)
400where
401 F: FnOnce(Timezone) -> UnaryFunc,
402{
403 let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
404 unreachable!()
405 };
406 let tz_str = expr1.as_literal_str().unwrap();
407 *e = match parse_timezone(tz_str, TimezoneSpec::Posix) {
408 Ok(tz) => MirScalarExpr::CallUnary {
409 func: build_unary(tz),
410 expr: Box::new(expr2.take()),
411 },
412 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
413 };
414}
415
416fn precompile_to_char<F>(e: &mut MirScalarExpr, build_unary: F)
419where
420 F: FnOnce(String, DateTimeFormat) -> UnaryFunc,
421{
422 let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
423 unreachable!()
424 };
425 let format_str = expr2.as_literal_str().unwrap().to_owned();
426 let compiled = DateTimeFormat::compile(&format_str);
427 *e = MirScalarExpr::CallUnary {
428 func: build_unary(format_str, compiled),
429 expr: Box::new(expr1.take()),
430 };
431}
432
433fn precompile_is_like(
436 e: &mut MirScalarExpr,
437 column_types: &[ReprColumnType],
438 case_insensitive: bool,
439) {
440 let MirScalarExpr::CallBinary { expr1, expr2, .. } = e else {
441 unreachable!()
442 };
443 let pattern = expr2.as_literal_str().unwrap();
444 *e = match like_pattern::compile(pattern, case_insensitive) {
445 Ok(matcher) => expr1
446 .take()
447 .call_unary(UnaryFunc::IsLikeMatch(func::IsLikeMatch(matcher))),
448 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
449 };
450}
451
#[cfg(test)]
mod tests {
    use mz_repr::adt::interval::Interval;
    use mz_repr::{Datum, ReprScalarType};

    use crate::MirScalarExpr;
    use crate::scalar::func;

    /// Integer operations fold to the column operand exactly when the literal
    /// operand is an identity element in a position where it acts as one.
    #[mz_ore::test]
    fn identity_operand_folds() {
        let int32 = [ReprScalarType::Int32.nullable(true)];
        let col = || MirScalarExpr::column(0);
        let lit = |v| MirScalarExpr::literal_ok(Datum::Int32(v), ReprScalarType::Int32);

        // Each of these must reduce to the bare column.
        for mut e in [
            col().call_binary(lit(0), func::AddInt32),
            lit(0).call_binary(col(), func::AddInt32),
            col().call_binary(lit(0), func::SubInt32),
            col().call_binary(lit(1), func::MulInt32),
            lit(1).call_binary(col(), func::MulInt32),
            col().call_binary(lit(1), func::DivInt32),
            col().call_binary(lit(0), func::BitOrInt32),
            col().call_binary(lit(0), func::BitXorInt32),
            col().call_binary(lit(-1), func::BitAndInt32),
            col().call_binary(lit(0), func::BitShiftLeftInt32),
            col().call_binary(lit(0), func::BitShiftRightInt32),
        ] {
            e.reduce(&int32);
            assert_eq!(e, col(), "expected fold to the column");
        }

        // None of these may reduce to the bare column: the literal is either
        // not an identity element, or sits on the left of a non-commutative
        // operation.
        for mut e in [
            col().call_binary(lit(1), func::AddInt32),
            lit(0).call_binary(col(), func::SubInt32),
            lit(1).call_binary(col(), func::DivInt32),
            col().call_binary(lit(0), func::BitAndInt32),
        ] {
            e.reduce(&int32);
            assert_ne!(e, col(), "expected no fold to the column");
        }
    }

    /// Trimming with an empty string and adding a zero interval to a timestamp
    /// both fold to the column operand.
    #[mz_ore::test]
    fn identity_operand_folds_trim_and_interval() {
        let col = || MirScalarExpr::column(0);

        let string = [ReprScalarType::String.nullable(true)];
        let empty = || MirScalarExpr::literal_ok(Datum::String(""), ReprScalarType::String);
        for f in [
            crate::BinaryFunc::Trim(func::Trim),
            crate::BinaryFunc::TrimLeading(func::TrimLeading),
            crate::BinaryFunc::TrimTrailing(func::TrimTrailing),
        ] {
            let mut e = col().call_binary(empty(), f);
            e.reduce(&string);
            assert_eq!(e, col());
        }

        let timestamp = [ReprScalarType::Timestamp {}.nullable(true)];
        let zero = MirScalarExpr::literal_ok(
            Datum::Interval(Interval::default()),
            ReprScalarType::Interval,
        );
        let mut e = col().call_binary(zero, func::AddTimestampInterval);
        // The argument had been corrupted to `×tamp` (HTML-entity mangling of
        // `&timestamp`), which does not compile.
        e.reduce(&timestamp);
        assert_eq!(e, col());
    }
}