Skip to main content

mz_expr/scalar/reduce/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallVariadic` nodes.
11
12use std::collections::BTreeSet;
13use std::mem;
14
15use mz_ore::collections::CollectionExt;
16use mz_pgtz::timezone::TimezoneSpec;
17use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena, SqlScalarType};
18
19use crate::scalar::func::variadic::{Coalesce, ListCreate, ListIndex};
20use crate::scalar::func::{
21    self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone, regexp_replace_parse_flags,
22};
23use crate::{Eval, MirScalarExpr};
24
25pub(super) fn reduce_call_variadic(
26    e: &mut MirScalarExpr,
27    column_types: &[ReprColumnType],
28    temp_storage: &RowArena,
29) {
30    // Flatten chains of associative variadic calls before any per-`func`
31    // dispatch. `undistribute_and_or` below relies on this having run.
32    e.flatten_associative();
33
34    let MirScalarExpr::CallVariadic { func, exprs } = e else {
35        unreachable!("`flatten_associative` shouldn't change node type");
36    };
37
38    // Coalesce has its own simplification routine that handles null/error
39    // propagation internally — bail out to it.
40    if *func == Coalesce.into() {
41        simplify_coalesce(e, column_types);
42        return;
43    }
44
45    // Generic folds: constant-fold, null-propagate, error-propagate.
46    if exprs.iter().all(|x| x.is_literal()) {
47        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
48        return;
49    }
50    if func.propagates_nulls() && exprs.iter().any(|x| x.is_literal_null()) {
51        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
52        return;
53    }
54    // Only a strict (null-propagating) function propagates an operand's error
55    // unconditionally. A non-strict function such as AND/OR has a dominating
56    // operand (`false`/`true`) that absorbs another operand's error at runtime
57    // — e.g. `false AND <error>` evaluates to `false` — so folding the whole
58    // call to the error here would introduce an error the evaluated expression
59    // never raises. (Coalesce, also non-strict, bailed out above.)
60    if func.propagates_nulls() {
61        if let Some(err) = exprs.iter().find_map(|x| x.as_literal_err()) {
62            *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
63            return;
64        }
65    }
66
67    // Per-function dispatch. Arms are mutually exclusive on discriminant; the
68    // bodies only fire when their literal-argument guards hold.
69    match func {
70        VariadicFunc::Greatest(_) | VariadicFunc::Least(_) => {
71            reduce_greatest_least(e, column_types);
72        }
73        VariadicFunc::Substr(_)
74            if exprs.len() == 2 && matches!(exprs[1].as_literal(), Some(Ok(Datum::Int32(1)))) =>
75        {
76            // `substr(s, 1)` — the two-argument form — keeps the entire
77            // string, and its evaluation at a start of one is infallible.
78            *e = exprs.swap_remove(0);
79        }
80        VariadicFunc::RegexpMatch(_)
81            if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
82        {
83            let needle = exprs[1].as_literal_str().unwrap();
84            let flags = if exprs.len() == 3 {
85                exprs[2].as_literal_str().unwrap()
86            } else {
87                ""
88            };
89            *e = match func::build_regex(needle, flags) {
90                Ok(regex) => mem::take(exprs)
91                    .into_first()
92                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
93                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
94            };
95        }
96        VariadicFunc::RegexpReplace(_)
97            if exprs[1].is_literal() && exprs.get(3).map_or(true, |e| e.is_literal()) =>
98        {
99            let pattern = exprs[1].as_literal_str().unwrap();
100            let flags = exprs
101                .get(3)
102                .map_or("", |expr| expr.as_literal_str().unwrap());
103            let (limit, flags) = regexp_replace_parse_flags(flags);
104
105            // The behavior of `regexp_replace` is that if the data is `NULL`, the
106            // function returns `NULL`, independently of whether the pattern or
107            // flags are correct. We need to check for this case and introduce an
108            // if-then-else on the error path to only surface the error if both
109            // the source and replacement inputs are non-NULL.
110            *e = match func::build_regex(pattern, &flags) {
111                Ok(regex) => {
112                    let mut exprs = mem::take(exprs);
113                    let replacement = exprs.swap_remove(2);
114                    let source = exprs.swap_remove(0);
115                    source.call_binary(
116                        replacement,
117                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
118                    )
119                }
120                Err(err) => {
121                    let mut exprs = mem::take(exprs);
122                    let replacement = exprs.swap_remove(2);
123                    let source = exprs.swap_remove(0);
124                    let scalar_type = e.typ(column_types).scalar_type;
125                    // We need to return `NULL` on `NULL` input, and error otherwise.
126                    source
127                        .call_is_null()
128                        .or(replacement.call_is_null())
129                        .if_then_else(
130                            MirScalarExpr::literal_null(scalar_type.clone()),
131                            MirScalarExpr::literal(Err(err), scalar_type),
132                        )
133                }
134            };
135        }
136        VariadicFunc::RegexpSplitToArray(_)
137            if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
138        {
139            let needle = exprs[1].as_literal_str().unwrap();
140            let flags = if exprs.len() == 3 {
141                exprs[2].as_literal_str().unwrap()
142            } else {
143                ""
144            };
145            *e = match func::build_regex(needle, flags) {
146                Ok(regex) => {
147                    mem::take(exprs)
148                        .into_first()
149                        .call_unary(UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(
150                            regex,
151                        )))
152                }
153                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
154            };
155        }
156        VariadicFunc::ListIndex(_) if is_list_create_call(&exprs[0]) => {
157            // We are looking for ListIndex(ListCreate, literal), and eliminate
158            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
159            let ind_exprs = exprs.split_off(1);
160            let top_list_create = exprs.swap_remove(0);
161            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
162        }
163        VariadicFunc::And(_) | VariadicFunc::Or(_) => {
164            // Note: It's important that we have called `flatten_associative` above.
165            e.undistribute_and_or();
166            e.reduce_and_canonicalize_and_or();
167        }
168        VariadicFunc::TimezoneTimeVariadic(_)
169            if exprs[0].is_literal() && exprs[2].is_literal_ok() =>
170        {
171            let tz = exprs[0].as_literal_str().unwrap();
172            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
173                Ok(tz) => MirScalarExpr::CallUnary {
174                    func: UnaryFunc::TimezoneTime(func::TimezoneTime {
175                        tz,
176                        wall_time: exprs[2]
177                            .as_literal()
178                            .unwrap()
179                            .unwrap()
180                            .unwrap_timestamptz()
181                            .naive_utc(),
182                    }),
183                    expr: Box::new(exprs[1].take()),
184                },
185                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
186            };
187        }
188        _ => {}
189    }
190}
191
192/// Simplifies a `Greatest`/`Least` call:
193/// 1. Deduplicate structurally equal operands, keeping first occurrences.
194///    Scalar evaluation is deterministic (the `And`/`Or` and `Coalesce`
195///    reducers already rely on this when they deduplicate), so duplicates of
196///    an expression contribute the same value — over which `greatest`/`least`
197///    are idempotent — or the same error; keeping the first occurrence leaves
198///    unchanged which error surfaces.
199/// 2. Drop literal null operands: both functions ignore null inputs (they
200///    return the max/min of the non-null inputs, and null only when every
201///    input is null).
202/// 3. A call left with a single operand is the identity on it — the call
203///    evaluates the operand once and returns it, null or not — and a call
204///    left with none is null.
205fn reduce_greatest_least(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
206    let typ = e.typ(column_types).scalar_type;
207    let MirScalarExpr::CallVariadic { exprs, .. } = e else {
208        unreachable!()
209    };
210    let mut seen = BTreeSet::new();
211    exprs.retain(|x| seen.insert(x.clone()));
212    exprs.retain(|x| !x.is_literal_null());
213    match exprs.len() {
214        0 => *e = MirScalarExpr::literal_null(typ),
215        1 => *e = exprs.swap_remove(0),
216        _ => {}
217    }
218}
219
220/// Simplifies a `Coalesce`:
221/// 1. If all arguments are null, the result is null.
222/// 2. Drop null arguments (none of them can be the result).
223/// 3. Truncate after the first argument known to be non-null (a literal or a
224///    non-nullable column).
225/// 4. Deduplicate arguments (e.g. `coalesce(#0, #0) → coalesce(#0)`).
226/// 5. Unwrap a single-argument `coalesce`.
227fn simplify_coalesce(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
228    let MirScalarExpr::CallVariadic { exprs, .. } = e else {
229        unreachable!()
230    };
231
232    // If all inputs are null, output is null. This check must
233    // be done before `exprs.retain...` because `e.typ` requires
234    // > 0 `exprs` remain.
235    if exprs.iter().all(|x| x.is_literal_null()) {
236        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
237        return;
238    }
239
240    // Remove any null values if not all values are null.
241    exprs.retain(|x| !x.is_literal_null());
242
243    // Find the first argument that is a literal or non-nullable
244    // column. All arguments after it get ignored, so throw them
245    // away. This intentionally throws away errors that can
246    // never happen.
247    if let Some(i) = exprs
248        .iter()
249        .position(|x| x.is_literal() || !x.typ(column_types).nullable)
250    {
251        exprs.truncate(i + 1);
252    }
253
254    // Deduplicate arguments in cases like `coalesce(#0, #0)`.
255    let mut seen = BTreeSet::new();
256    exprs.retain(|x| seen.insert(x.clone()));
257
258    if exprs.len() == 1 {
259        // Only one argument, so the coalesce is a no-op.
260        *e = exprs[0].take();
261    }
262}
263
264fn is_list_create_call(expr: &MirScalarExpr) -> bool {
265    matches!(
266        expr,
267        MirScalarExpr::CallVariadic {
268            func: VariadicFunc::ListCreate(..),
269            ..
270        }
271    )
272}
273
274fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
275    if let MirScalarExpr::CallVariadic {
276        func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
277        ..
278    } = list_create
279    {
280        ReprScalarType::from(typ)
281    } else {
282        unreachable!()
283    }
284}
285
286/// Partial-evaluates a list indexing with a literal directly after a list creation.
287///
288/// Multi-dimensional lists are handled by a single call to this function, with multiple
289/// elements in index_exprs (of which not all need to be literals), and nested ListCreates
290/// in list_create_to_reduce.
291///
292/// # Examples
293///
294/// `LIST[f1,f2][2]` --> `f2`.
295///
296/// A multi-dimensional list, with only some of the indexes being literals:
297/// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
298///
299/// See more examples in list.slt.
300fn reduce_list_create_list_index_literal(
301    mut list_create_to_reduce: MirScalarExpr,
302    mut index_exprs: Vec<MirScalarExpr>,
303) -> MirScalarExpr {
304    // We iterate over the index_exprs and remove literals, but keep non-literals.
305    // When we encounter a non-literal, we need to dig into the nested ListCreates:
306    // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
307    // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
308    // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
309    // that are at the current level (except those that disappeared due to
310    // literals at earlier levels), index into them with the literal, and change each
311    // element in `list_create_mut_refs` to the result.
312    // We also record mut refs to all the earlier `element_type` references that we have
313    // seen in ListCreate calls, because when we process a literal index, we need to remove
314    // one layer of list type from all these earlier ListCreate `element_type`s.
315    let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
316    let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
317    let mut i = 0;
318    while i < index_exprs.len()
319        && list_create_mut_refs
320            .iter()
321            .all(|lc| is_list_create_call(lc))
322    {
323        if index_exprs[i].is_literal_ok() {
324            // We can remove this index.
325            let removed_index = index_exprs.remove(i);
326            let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
327                Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
328                _ => unreachable!(), // always an Int64, see plan_index_list
329            };
330            // For each list_create referenced by list_create_mut_refs, substitute it by its
331            // `index`th argument (or null).
332            for list_create in &mut list_create_mut_refs {
333                let list_create_args = match list_create {
334                    MirScalarExpr::CallVariadic {
335                        func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
336                        exprs,
337                    } => exprs,
338                    _ => unreachable!(), // func cannot be anything else than a ListCreate
339                };
340                // ListIndex gives null on an out-of-bounds index
341                if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap() {
342                    let index: usize = index_i64.try_into().unwrap();
343                    **list_create = list_create_args.swap_remove(index);
344                } else {
345                    let typ = list_create_type(list_create);
346                    **list_create = MirScalarExpr::literal_null(typ);
347                }
348            }
349            // Peel one layer off of each of the earlier element types.
350            for t in earlier_list_create_types.iter_mut() {
351                if let SqlScalarType::List {
352                    element_type,
353                    custom_id: _,
354                } = t
355                {
356                    **t = *element_type.clone();
357                    // These are not the same types anymore, so remove custom_ids all the
358                    // way down.
359                    let mut u = &mut **t;
360                    while let SqlScalarType::List {
361                        element_type,
362                        custom_id,
363                    } = u
364                    {
365                        *custom_id = None;
366                        u = &mut **element_type;
367                    }
368                } else {
369                    unreachable!("already matched below");
370                }
371            }
372        } else {
373            // We can't remove this index, so we can't reduce any of the ListCreates at this
374            // level. So we change list_create_mut_refs to refer to all the arguments of all
375            // the ListCreates currently referenced by list_create_mut_refs.
376            list_create_mut_refs = list_create_mut_refs
377                .into_iter()
378                .flat_map(|list_create| match list_create {
379                    MirScalarExpr::CallVariadic {
380                        func: VariadicFunc::ListCreate(ListCreate { elem_type }),
381                        exprs: list_create_args,
382                    } => {
383                        earlier_list_create_types.push(elem_type);
384                        list_create_args
385                    }
386                    // func cannot be anything else than a ListCreate
387                    _ => unreachable!(),
388                })
389                .collect();
390            i += 1;
391        }
392    }
393    // If all list indexes have been evaluated, return the reduced expression.
394    // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
395    if index_exprs.is_empty() {
396        assert_eq!(list_create_mut_refs.len(), 1);
397        list_create_to_reduce
398    } else {
399        MirScalarExpr::call_variadic(
400            ListIndex,
401            std::iter::once(list_create_to_reduce)
402                .chain(index_exprs)
403                .collect(),
404        )
405    }
406}
407
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, ReprScalarType};

    use crate::MirScalarExpr;
    use crate::scalar::func::variadic::{Greatest, Least, Substr};

    /// `Greatest`/`Least` reduction: literal nulls are dropped, repeated
    /// operands are deduplicated, and the degenerate arities collapse.
    #[mz_ore::test]
    fn greatest_least_null_operand_drop() {
        let types = [
            ReprScalarType::Int32.nullable(true),
            ReprScalarType::Int32.nullable(true),
        ];
        let null = || MirScalarExpr::literal_null(ReprScalarType::Int32);
        let col = MirScalarExpr::column;
        // Runs `reduce` against `types` and hands the expression back.
        let reduced = |mut expr: MirScalarExpr| {
            expr.reduce(&types);
            expr
        };

        // A literal null operand disappears; the lone survivor is the result.
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(Greatest, vec![col(0), null()])),
            col(0)
        );

        // With two survivors the call stays, minus the null.
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(
                Least,
                vec![col(0), null(), col(1)]
            )),
            MirScalarExpr::call_variadic(Least, vec![col(0), col(1)])
        );

        // Structurally equal operands collapse onto the first occurrence
        // (evaluation is deterministic and greatest/least are idempotent).
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(Greatest, vec![col(0), col(0)])),
            col(0)
        );
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(
                Least,
                vec![col(1), col(0), col(1)]
            )),
            MirScalarExpr::call_variadic(Least, vec![col(1), col(0)])
        );

        // Deduplication and null removal together leave the bare operand.
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(
                Greatest,
                vec![col(0), null(), col(0)]
            )),
            col(0)
        );

        // Nothing but nulls reduces to a null literal.
        assert!(
            reduced(MirScalarExpr::call_variadic(Greatest, vec![null(), null()]))
                .is_literal_null()
        );
    }

    /// `substr(s, 1)` is removed only in its two-argument form.
    #[mz_ore::test]
    fn substr_from_one() {
        let types = [ReprScalarType::String.nullable(true)];
        let col = || MirScalarExpr::column(0);
        let lit = |v| MirScalarExpr::literal_ok(Datum::Int32(v), ReprScalarType::Int32);
        // Runs `reduce` against `types` and hands the expression back.
        let reduced = |mut expr: MirScalarExpr| {
            expr.reduce(&types);
            expr
        };

        // Two arguments, start of one: the call is the identity on the string.
        assert_eq!(
            reduced(MirScalarExpr::call_variadic(Substr, vec![col(), lit(1)])),
            col()
        );

        // Three arguments: the length bound truncates, so the call must stay.
        assert_ne!(
            reduced(MirScalarExpr::call_variadic(
                Substr,
                vec![col(), lit(1), lit(5)]
            )),
            col()
        );
    }
}