1use std::collections::BTreeSet;
13use std::mem;
14
15use mz_ore::collections::CollectionExt;
16use mz_pgtz::timezone::TimezoneSpec;
17use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena, SqlScalarType};
18
19use crate::scalar::func::variadic::{Coalesce, ListCreate, ListIndex};
20use crate::scalar::func::{
21 self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone, regexp_replace_parse_flags,
22};
23use crate::{Eval, MirScalarExpr};
24
25pub(super) fn reduce_call_variadic(
26 e: &mut MirScalarExpr,
27 column_types: &[ReprColumnType],
28 temp_storage: &RowArena,
29) {
30 e.flatten_associative();
33
34 let MirScalarExpr::CallVariadic { func, exprs } = e else {
35 unreachable!("`flatten_associative` shouldn't change node type");
36 };
37
38 if *func == Coalesce.into() {
41 simplify_coalesce(e, column_types);
42 return;
43 }
44
45 if exprs.iter().all(|x| x.is_literal()) {
47 *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
48 return;
49 }
50 if func.propagates_nulls() && exprs.iter().any(|x| x.is_literal_null()) {
51 *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
52 return;
53 }
54 if func.propagates_nulls() {
61 if let Some(err) = exprs.iter().find_map(|x| x.as_literal_err()) {
62 *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
63 return;
64 }
65 }
66
67 match func {
70 VariadicFunc::Greatest(_) | VariadicFunc::Least(_) => {
71 reduce_greatest_least(e, column_types);
72 }
73 VariadicFunc::Substr(_)
74 if exprs.len() == 2 && matches!(exprs[1].as_literal(), Some(Ok(Datum::Int32(1)))) =>
75 {
76 *e = exprs.swap_remove(0);
79 }
80 VariadicFunc::RegexpMatch(_)
81 if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
82 {
83 let needle = exprs[1].as_literal_str().unwrap();
84 let flags = if exprs.len() == 3 {
85 exprs[2].as_literal_str().unwrap()
86 } else {
87 ""
88 };
89 *e = match func::build_regex(needle, flags) {
90 Ok(regex) => mem::take(exprs)
91 .into_first()
92 .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
93 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
94 };
95 }
96 VariadicFunc::RegexpReplace(_)
97 if exprs[1].is_literal() && exprs.get(3).map_or(true, |e| e.is_literal()) =>
98 {
99 let pattern = exprs[1].as_literal_str().unwrap();
100 let flags = exprs
101 .get(3)
102 .map_or("", |expr| expr.as_literal_str().unwrap());
103 let (limit, flags) = regexp_replace_parse_flags(flags);
104
105 *e = match func::build_regex(pattern, &flags) {
111 Ok(regex) => {
112 let mut exprs = mem::take(exprs);
113 let replacement = exprs.swap_remove(2);
114 let source = exprs.swap_remove(0);
115 source.call_binary(
116 replacement,
117 BinaryFunc::from(func::RegexpReplace { regex, limit }),
118 )
119 }
120 Err(err) => {
121 let mut exprs = mem::take(exprs);
122 let replacement = exprs.swap_remove(2);
123 let source = exprs.swap_remove(0);
124 let scalar_type = e.typ(column_types).scalar_type;
125 source
127 .call_is_null()
128 .or(replacement.call_is_null())
129 .if_then_else(
130 MirScalarExpr::literal_null(scalar_type.clone()),
131 MirScalarExpr::literal(Err(err), scalar_type),
132 )
133 }
134 };
135 }
136 VariadicFunc::RegexpSplitToArray(_)
137 if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
138 {
139 let needle = exprs[1].as_literal_str().unwrap();
140 let flags = if exprs.len() == 3 {
141 exprs[2].as_literal_str().unwrap()
142 } else {
143 ""
144 };
145 *e = match func::build_regex(needle, flags) {
146 Ok(regex) => {
147 mem::take(exprs)
148 .into_first()
149 .call_unary(UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(
150 regex,
151 )))
152 }
153 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
154 };
155 }
156 VariadicFunc::ListIndex(_) if is_list_create_call(&exprs[0]) => {
157 let ind_exprs = exprs.split_off(1);
160 let top_list_create = exprs.swap_remove(0);
161 *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
162 }
163 VariadicFunc::And(_) | VariadicFunc::Or(_) => {
164 e.undistribute_and_or();
166 e.reduce_and_canonicalize_and_or();
167 }
168 VariadicFunc::TimezoneTimeVariadic(_)
169 if exprs[0].is_literal() && exprs[2].is_literal_ok() =>
170 {
171 let tz = exprs[0].as_literal_str().unwrap();
172 *e = match parse_timezone(tz, TimezoneSpec::Posix) {
173 Ok(tz) => MirScalarExpr::CallUnary {
174 func: UnaryFunc::TimezoneTime(func::TimezoneTime {
175 tz,
176 wall_time: exprs[2]
177 .as_literal()
178 .unwrap()
179 .unwrap()
180 .unwrap_timestamptz()
181 .naive_utc(),
182 }),
183 expr: Box::new(exprs[1].take()),
184 },
185 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
186 };
187 }
188 _ => {}
189 }
190}
191
192fn reduce_greatest_least(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
206 let typ = e.typ(column_types).scalar_type;
207 let MirScalarExpr::CallVariadic { exprs, .. } = e else {
208 unreachable!()
209 };
210 let mut seen = BTreeSet::new();
211 exprs.retain(|x| seen.insert(x.clone()));
212 exprs.retain(|x| !x.is_literal_null());
213 match exprs.len() {
214 0 => *e = MirScalarExpr::literal_null(typ),
215 1 => *e = exprs.swap_remove(0),
216 _ => {}
217 }
218}
219
220fn simplify_coalesce(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
228 let MirScalarExpr::CallVariadic { exprs, .. } = e else {
229 unreachable!()
230 };
231
232 if exprs.iter().all(|x| x.is_literal_null()) {
236 *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
237 return;
238 }
239
240 exprs.retain(|x| !x.is_literal_null());
242
243 if let Some(i) = exprs
248 .iter()
249 .position(|x| x.is_literal() || !x.typ(column_types).nullable)
250 {
251 exprs.truncate(i + 1);
252 }
253
254 let mut seen = BTreeSet::new();
256 exprs.retain(|x| seen.insert(x.clone()));
257
258 if exprs.len() == 1 {
259 *e = exprs[0].take();
261 }
262}
263
264fn is_list_create_call(expr: &MirScalarExpr) -> bool {
265 matches!(
266 expr,
267 MirScalarExpr::CallVariadic {
268 func: VariadicFunc::ListCreate(..),
269 ..
270 }
271 )
272}
273
274fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
275 if let MirScalarExpr::CallVariadic {
276 func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
277 ..
278 } = list_create
279 {
280 ReprScalarType::from(typ)
281 } else {
282 unreachable!()
283 }
284}
285
286fn reduce_list_create_list_index_literal(
301 mut list_create_to_reduce: MirScalarExpr,
302 mut index_exprs: Vec<MirScalarExpr>,
303) -> MirScalarExpr {
304 let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
316 let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
317 let mut i = 0;
318 while i < index_exprs.len()
319 && list_create_mut_refs
320 .iter()
321 .all(|lc| is_list_create_call(lc))
322 {
323 if index_exprs[i].is_literal_ok() {
324 let removed_index = index_exprs.remove(i);
326 let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
327 Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
328 _ => unreachable!(), };
330 for list_create in &mut list_create_mut_refs {
333 let list_create_args = match list_create {
334 MirScalarExpr::CallVariadic {
335 func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
336 exprs,
337 } => exprs,
338 _ => unreachable!(), };
340 if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap() {
342 let index: usize = index_i64.try_into().unwrap();
343 **list_create = list_create_args.swap_remove(index);
344 } else {
345 let typ = list_create_type(list_create);
346 **list_create = MirScalarExpr::literal_null(typ);
347 }
348 }
349 for t in earlier_list_create_types.iter_mut() {
351 if let SqlScalarType::List {
352 element_type,
353 custom_id: _,
354 } = t
355 {
356 **t = *element_type.clone();
357 let mut u = &mut **t;
360 while let SqlScalarType::List {
361 element_type,
362 custom_id,
363 } = u
364 {
365 *custom_id = None;
366 u = &mut **element_type;
367 }
368 } else {
369 unreachable!("already matched below");
370 }
371 }
372 } else {
373 list_create_mut_refs = list_create_mut_refs
377 .into_iter()
378 .flat_map(|list_create| match list_create {
379 MirScalarExpr::CallVariadic {
380 func: VariadicFunc::ListCreate(ListCreate { elem_type }),
381 exprs: list_create_args,
382 } => {
383 earlier_list_create_types.push(elem_type);
384 list_create_args
385 }
386 _ => unreachable!(),
388 })
389 .collect();
390 i += 1;
391 }
392 }
393 if index_exprs.is_empty() {
396 assert_eq!(list_create_mut_refs.len(), 1);
397 list_create_to_reduce
398 } else {
399 MirScalarExpr::call_variadic(
400 ListIndex,
401 std::iter::once(list_create_to_reduce)
402 .chain(index_exprs)
403 .collect(),
404 )
405 }
406}
407
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, ReprScalarType, SqlScalarType};

    use crate::MirScalarExpr;
    use crate::scalar::func::variadic::{Coalesce, Greatest, Least, ListCreate, ListIndex, Substr};

    #[mz_ore::test]
    fn greatest_least_null_operand_drop() {
        let types = [
            ReprScalarType::Int32.nullable(true),
            ReprScalarType::Int32.nullable(true),
        ];
        let null = || MirScalarExpr::literal_null(ReprScalarType::Int32);
        let col = MirScalarExpr::column;

        let mut e = MirScalarExpr::call_variadic(Greatest, vec![col(0), null()]);
        e.reduce(&types);
        assert_eq!(e, col(0));

        let mut e = MirScalarExpr::call_variadic(Least, vec![col(0), null(), col(1)]);
        e.reduce(&types);
        assert_eq!(e, MirScalarExpr::call_variadic(Least, vec![col(0), col(1)]));

        let mut e = MirScalarExpr::call_variadic(Greatest, vec![col(0), col(0)]);
        e.reduce(&types);
        assert_eq!(e, col(0));

        let mut e = MirScalarExpr::call_variadic(Least, vec![col(1), col(0), col(1)]);
        e.reduce(&types);
        assert_eq!(e, MirScalarExpr::call_variadic(Least, vec![col(1), col(0)]));

        let mut e = MirScalarExpr::call_variadic(Greatest, vec![col(0), null(), col(0)]);
        e.reduce(&types);
        assert_eq!(e, col(0));

        let mut e = MirScalarExpr::call_variadic(Greatest, vec![null(), null()]);
        e.reduce(&types);
        assert!(e.is_literal_null());
    }

    #[mz_ore::test]
    fn substr_from_one() {
        let types = [ReprScalarType::String.nullable(true)];
        let col = || MirScalarExpr::column(0);
        let lit = |v| MirScalarExpr::literal_ok(Datum::Int32(v), ReprScalarType::Int32);

        let mut e = MirScalarExpr::call_variadic(Substr, vec![col(), lit(1)]);
        e.reduce(&types);
        assert_eq!(e, col());

        let mut e = MirScalarExpr::call_variadic(Substr, vec![col(), lit(1), lit(5)]);
        e.reduce(&types);
        assert_ne!(e, col());
    }

    #[mz_ore::test]
    fn coalesce_simplification() {
        let types = [
            ReprScalarType::Int32.nullable(true),
            ReprScalarType::Int32.nullable(true),
        ];
        let null = || MirScalarExpr::literal_null(ReprScalarType::Int32);
        let lit = |v| MirScalarExpr::literal_ok(Datum::Int32(v), ReprScalarType::Int32);
        let col = MirScalarExpr::column;

        // Literal NULLs and repeated arguments are dropped; a lone survivor replaces the call.
        let mut e = MirScalarExpr::call_variadic(Coalesce, vec![null(), col(0), null(), col(0)]);
        e.reduce(&types);
        assert_eq!(e, col(0));

        // Nothing after the first non-NULL literal can be reached.
        let mut e = MirScalarExpr::call_variadic(Coalesce, vec![col(0), lit(5), col(1)]);
        e.reduce(&types);
        assert_eq!(e, MirScalarExpr::call_variadic(Coalesce, vec![col(0), lit(5)]));

        // Only literal NULL arguments: the whole call is NULL.
        let mut e = MirScalarExpr::call_variadic(Coalesce, vec![null(), null()]);
        e.reduce(&types);
        assert!(e.is_literal_null());
    }

    #[mz_ore::test]
    fn list_index_of_list_create() {
        let types = [
            ReprScalarType::Int64.nullable(true),
            ReprScalarType::Int64.nullable(true),
        ];
        let col = MirScalarExpr::column;
        let idx = |v| MirScalarExpr::literal_ok(Datum::Int64(v), ReprScalarType::Int64);
        let list = || {
            MirScalarExpr::call_variadic(
                ListCreate {
                    elem_type: SqlScalarType::Int64,
                },
                vec![col(0), col(1)],
            )
        };

        // SQL list indexes are 1-based.
        let mut e = MirScalarExpr::call_variadic(ListIndex, vec![list(), idx(2)]);
        e.reduce(&types);
        assert_eq!(e, col(1));

        // Non-positive and too-large indexes select a NULL element.
        for i in [0, 3] {
            let mut e = MirScalarExpr::call_variadic(ListIndex, vec![list(), idx(i)]);
            e.reduce(&types);
            assert!(e.is_literal_null());
        }
    }
}