mz_expr/scalar/reduce/variadic.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallVariadic` nodes.
11
12use std::collections::BTreeSet;
13use std::mem;
14
15use mz_ore::collections::CollectionExt;
16use mz_pgtz::timezone::TimezoneSpec;
17use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena, SqlScalarType};
18
19use crate::MirScalarExpr;
20use crate::scalar::func::variadic::{Coalesce, ListCreate, ListIndex};
21use crate::scalar::func::{
22 self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone, regexp_replace_parse_flags,
23};
24
25pub(super) fn reduce_call_variadic(
26 e: &mut MirScalarExpr,
27 column_types: &[ReprColumnType],
28 temp_storage: &RowArena,
29) {
30 // Flatten chains of associative variadic calls before any per-`func`
31 // dispatch. `undistribute_and_or` below relies on this having run.
32 e.flatten_associative();
33
34 let MirScalarExpr::CallVariadic { func, exprs } = e else {
35 unreachable!("`flatten_associative` shouldn't change node type");
36 };
37
38 // Coalesce has its own simplification routine that handles null/error
39 // propagation internally — bail out to it.
40 if *func == Coalesce.into() {
41 simplify_coalesce(e, column_types);
42 return;
43 }
44
45 // Generic folds: constant-fold, null-propagate, error-propagate.
46 if exprs.iter().all(|x| x.is_literal()) {
47 *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
48 return;
49 }
50 if func.propagates_nulls() && exprs.iter().any(|x| x.is_literal_null()) {
51 *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
52 return;
53 }
54 if let Some(err) = exprs.iter().find_map(|x| x.as_literal_err()) {
55 *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
56 return;
57 }
58
59 // Per-function dispatch. Arms are mutually exclusive on discriminant; the
60 // bodies only fire when their literal-argument guards hold.
61 match func {
62 VariadicFunc::RegexpMatch(_)
63 if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
64 {
65 let needle = exprs[1].as_literal_str().unwrap();
66 let flags = if exprs.len() == 3 {
67 exprs[2].as_literal_str().unwrap()
68 } else {
69 ""
70 };
71 *e = match func::build_regex(needle, flags) {
72 Ok(regex) => mem::take(exprs)
73 .into_first()
74 .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
75 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
76 };
77 }
78 VariadicFunc::RegexpReplace(_)
79 if exprs[1].is_literal() && exprs.get(3).map_or(true, |e| e.is_literal()) =>
80 {
81 let pattern = exprs[1].as_literal_str().unwrap();
82 let flags = exprs
83 .get(3)
84 .map_or("", |expr| expr.as_literal_str().unwrap());
85 let (limit, flags) = regexp_replace_parse_flags(flags);
86
87 // The behavior of `regexp_replace` is that if the data is `NULL`, the
88 // function returns `NULL`, independently of whether the pattern or
89 // flags are correct. We need to check for this case and introduce an
90 // if-then-else on the error path to only surface the error if both
91 // the source and replacement inputs are non-NULL.
92 *e = match func::build_regex(pattern, &flags) {
93 Ok(regex) => {
94 let mut exprs = mem::take(exprs);
95 let replacement = exprs.swap_remove(2);
96 let source = exprs.swap_remove(0);
97 source.call_binary(
98 replacement,
99 BinaryFunc::from(func::RegexpReplace { regex, limit }),
100 )
101 }
102 Err(err) => {
103 let mut exprs = mem::take(exprs);
104 let replacement = exprs.swap_remove(2);
105 let source = exprs.swap_remove(0);
106 let scalar_type = e.typ(column_types).scalar_type;
107 // We need to return `NULL` on `NULL` input, and error otherwise.
108 source
109 .call_is_null()
110 .or(replacement.call_is_null())
111 .if_then_else(
112 MirScalarExpr::literal_null(scalar_type.clone()),
113 MirScalarExpr::literal(Err(err), scalar_type),
114 )
115 }
116 };
117 }
118 VariadicFunc::RegexpSplitToArray(_)
119 if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
120 {
121 let needle = exprs[1].as_literal_str().unwrap();
122 let flags = if exprs.len() == 3 {
123 exprs[2].as_literal_str().unwrap()
124 } else {
125 ""
126 };
127 *e = match func::build_regex(needle, flags) {
128 Ok(regex) => {
129 mem::take(exprs)
130 .into_first()
131 .call_unary(UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(
132 regex,
133 )))
134 }
135 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
136 };
137 }
138 VariadicFunc::ListIndex(_) if is_list_create_call(&exprs[0]) => {
139 // We are looking for ListIndex(ListCreate, literal), and eliminate
140 // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
141 let ind_exprs = exprs.split_off(1);
142 let top_list_create = exprs.swap_remove(0);
143 *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
144 }
145 VariadicFunc::And(_) | VariadicFunc::Or(_) => {
146 // Note: It's important that we have called `flatten_associative` above.
147 e.undistribute_and_or();
148 e.reduce_and_canonicalize_and_or();
149 }
150 VariadicFunc::TimezoneTimeVariadic(_)
151 if exprs[0].is_literal() && exprs[2].is_literal_ok() =>
152 {
153 let tz = exprs[0].as_literal_str().unwrap();
154 *e = match parse_timezone(tz, TimezoneSpec::Posix) {
155 Ok(tz) => MirScalarExpr::CallUnary {
156 func: UnaryFunc::TimezoneTime(func::TimezoneTime {
157 tz,
158 wall_time: exprs[2]
159 .as_literal()
160 .unwrap()
161 .unwrap()
162 .unwrap_timestamptz()
163 .naive_utc(),
164 }),
165 expr: Box::new(exprs[1].take()),
166 },
167 Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
168 };
169 }
170 _ => {}
171 }
172}
173
174/// Simplifies a `Coalesce`:
175/// 1. If all arguments are null, the result is null.
176/// 2. Drop null arguments (none of them can be the result).
177/// 3. Truncate after the first argument known to be non-null (a literal or a
178/// non-nullable column).
179/// 4. Deduplicate arguments (e.g. `coalesce(#0, #0) → coalesce(#0)`).
180/// 5. Unwrap a single-argument `coalesce`.
181fn simplify_coalesce(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
182 let MirScalarExpr::CallVariadic { exprs, .. } = e else {
183 unreachable!()
184 };
185
186 // If all inputs are null, output is null. This check must
187 // be done before `exprs.retain...` because `e.typ` requires
188 // > 0 `exprs` remain.
189 if exprs.iter().all(|x| x.is_literal_null()) {
190 *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
191 return;
192 }
193
194 // Remove any null values if not all values are null.
195 exprs.retain(|x| !x.is_literal_null());
196
197 // Find the first argument that is a literal or non-nullable
198 // column. All arguments after it get ignored, so throw them
199 // away. This intentionally throws away errors that can
200 // never happen.
201 if let Some(i) = exprs
202 .iter()
203 .position(|x| x.is_literal() || !x.typ(column_types).nullable)
204 {
205 exprs.truncate(i + 1);
206 }
207
208 // Deduplicate arguments in cases like `coalesce(#0, #0)`.
209 let mut seen = BTreeSet::new();
210 exprs.retain(|x| seen.insert(x.clone()));
211
212 if exprs.len() == 1 {
213 // Only one argument, so the coalesce is a no-op.
214 *e = exprs[0].take();
215 }
216}
217
218fn is_list_create_call(expr: &MirScalarExpr) -> bool {
219 matches!(
220 expr,
221 MirScalarExpr::CallVariadic {
222 func: VariadicFunc::ListCreate(..),
223 ..
224 }
225 )
226}
227
228fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
229 if let MirScalarExpr::CallVariadic {
230 func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
231 ..
232 } = list_create
233 {
234 ReprScalarType::from(typ)
235 } else {
236 unreachable!()
237 }
238}
239
240/// Partial-evaluates a list indexing with a literal directly after a list creation.
241///
242/// Multi-dimensional lists are handled by a single call to this function, with multiple
243/// elements in index_exprs (of which not all need to be literals), and nested ListCreates
244/// in list_create_to_reduce.
245///
246/// # Examples
247///
248/// `LIST[f1,f2][2]` --> `f2`.
249///
250/// A multi-dimensional list, with only some of the indexes being literals:
251/// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
252///
253/// See more examples in list.slt.
254fn reduce_list_create_list_index_literal(
255 mut list_create_to_reduce: MirScalarExpr,
256 mut index_exprs: Vec<MirScalarExpr>,
257) -> MirScalarExpr {
258 // We iterate over the index_exprs and remove literals, but keep non-literals.
259 // When we encounter a non-literal, we need to dig into the nested ListCreates:
260 // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
261 // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
262 // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
263 // that are at the current level (except those that disappeared due to
264 // literals at earlier levels), index into them with the literal, and change each
265 // element in `list_create_mut_refs` to the result.
266 // We also record mut refs to all the earlier `element_type` references that we have
267 // seen in ListCreate calls, because when we process a literal index, we need to remove
268 // one layer of list type from all these earlier ListCreate `element_type`s.
269 let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
270 let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
271 let mut i = 0;
272 while i < index_exprs.len()
273 && list_create_mut_refs
274 .iter()
275 .all(|lc| is_list_create_call(lc))
276 {
277 if index_exprs[i].is_literal_ok() {
278 // We can remove this index.
279 let removed_index = index_exprs.remove(i);
280 let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
281 Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
282 _ => unreachable!(), // always an Int64, see plan_index_list
283 };
284 // For each list_create referenced by list_create_mut_refs, substitute it by its
285 // `index`th argument (or null).
286 for list_create in &mut list_create_mut_refs {
287 let list_create_args = match list_create {
288 MirScalarExpr::CallVariadic {
289 func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
290 exprs,
291 } => exprs,
292 _ => unreachable!(), // func cannot be anything else than a ListCreate
293 };
294 // ListIndex gives null on an out-of-bounds index
295 if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap() {
296 let index: usize = index_i64.try_into().unwrap();
297 **list_create = list_create_args.swap_remove(index);
298 } else {
299 let typ = list_create_type(list_create);
300 **list_create = MirScalarExpr::literal_null(typ);
301 }
302 }
303 // Peel one layer off of each of the earlier element types.
304 for t in earlier_list_create_types.iter_mut() {
305 if let SqlScalarType::List {
306 element_type,
307 custom_id: _,
308 } = t
309 {
310 **t = *element_type.clone();
311 // These are not the same types anymore, so remove custom_ids all the
312 // way down.
313 let mut u = &mut **t;
314 while let SqlScalarType::List {
315 element_type,
316 custom_id,
317 } = u
318 {
319 *custom_id = None;
320 u = &mut **element_type;
321 }
322 } else {
323 unreachable!("already matched below");
324 }
325 }
326 } else {
327 // We can't remove this index, so we can't reduce any of the ListCreates at this
328 // level. So we change list_create_mut_refs to refer to all the arguments of all
329 // the ListCreates currently referenced by list_create_mut_refs.
330 list_create_mut_refs = list_create_mut_refs
331 .into_iter()
332 .flat_map(|list_create| match list_create {
333 MirScalarExpr::CallVariadic {
334 func: VariadicFunc::ListCreate(ListCreate { elem_type }),
335 exprs: list_create_args,
336 } => {
337 earlier_list_create_types.push(elem_type);
338 list_create_args
339 }
340 // func cannot be anything else than a ListCreate
341 _ => unreachable!(),
342 })
343 .collect();
344 i += 1;
345 }
346 }
347 // If all list indexes have been evaluated, return the reduced expression.
348 // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
349 if index_exprs.is_empty() {
350 assert_eq!(list_create_mut_refs.len(), 1);
351 list_create_to_reduce
352 } else {
353 MirScalarExpr::call_variadic(
354 ListIndex,
355 std::iter::once(list_create_to_reduce)
356 .chain(index_exprs)
357 .collect(),
358 )
359 }
360}