Skip to main content

mz_expr/scalar/reduce/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallVariadic` nodes.
11
12use std::collections::BTreeSet;
13use std::mem;
14
15use mz_ore::collections::CollectionExt;
16use mz_pgtz::timezone::TimezoneSpec;
17use mz_repr::{Datum, ReprColumnType, ReprScalarType, RowArena, SqlScalarType};
18
19use crate::MirScalarExpr;
20use crate::scalar::func::variadic::{Coalesce, ListCreate, ListIndex};
21use crate::scalar::func::{
22    self, BinaryFunc, UnaryFunc, VariadicFunc, parse_timezone, regexp_replace_parse_flags,
23};
24
/// Applies the post-order simplifications for a single `CallVariadic` node, in place.
///
/// The rewrites are attempted in this order:
/// 1. Nested calls of the same associative function are flattened
///    (`flatten_associative`).
/// 2. `Coalesce` is handed off to [`simplify_coalesce`], and nothing else runs.
/// 3. Generic folds for every other function:
///    - if all arguments are literals, the call is evaluated now;
///    - if the function propagates nulls and some argument is a literal `NULL`, the call
///      becomes a typed `NULL`;
///    - if some argument is a literal error, the call becomes that error.
/// 4. Function-specific rewrites, such as:
///    - compiling a literal regex pattern once into a unary/binary call;
///    - collapsing `ListIndex(ListCreate, ...)`;
///    - canonicalizing `And`/`Or`;
///    - pre-parsing a literal timezone.
///
/// `column_types` is used for `e.typ(...)` when building replacement literals.
/// `temp_storage` is the arena passed to `eval` when constant-folding.
///
/// Expects `e` to be a `CallVariadic`; any other node hits the `unreachable!` below.
pub(super) fn reduce_call_variadic(
    e: &mut MirScalarExpr,
    column_types: &[ReprColumnType],
    temp_storage: &RowArena,
) {
    // Flatten chains of associative variadic calls before any per-`func`
    // dispatch. `undistribute_and_or` below relies on this having run.
    e.flatten_associative();

    let MirScalarExpr::CallVariadic { func, exprs } = e else {
        unreachable!("`flatten_associative` shouldn't change node type");
    };

    // `Coalesce` does not simply propagate nulls or errors: NULL arguments are
    // skipped, and arguments after the first known-non-null one are never
    // reached. The generic folds below would therefore be wrong for it, so
    // hand it to its own simplification routine and stop.
    if *func == Coalesce.into() {
        simplify_coalesce(e, column_types);
        return;
    }

    // Generic folds: constant-fold, null-propagate, error-propagate.

    // Constant-fold: every argument is a literal, so evaluate the call now.
    if exprs.iter().all(|x| x.is_literal()) {
        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
        return;
    }
    // Null-propagate: for functions that propagate nulls, a single literal
    // `NULL` argument determines the result.
    if func.propagates_nulls() && exprs.iter().any(|x| x.is_literal_null()) {
        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
        return;
    }
    // Error-propagate: the first literal error argument (in argument order)
    // becomes the result. This runs for every non-`Coalesce` function,
    // including `And`/`Or`.
    // NOTE(review): for `And`/`Or`, evaluation may not surface an error when
    // another argument already decides the result (e.g. `false AND <err>`);
    // confirm this eager fold is the intended semantics.
    if let Some(err) = exprs.iter().find_map(|x| x.as_literal_err()) {
        *e = MirScalarExpr::literal(Err(err.clone()), e.typ(column_types).scalar_type);
        return;
    }

    // Per-function dispatch. At most one arm matches, because each arm names a
    // distinct `VariadicFunc` variant. A guarded arm only rewrites when the
    // arguments it needs are literals; otherwise control falls through to
    // `_ => {}` and `e` is left as is.
    //
    // The `as_literal_str().unwrap()` calls below assume those literals are
    // non-NULL, non-error strings. Errors were removed by the fold above.
    // NOTE(review): ruling out NULLs relies on these functions returning true
    // from `propagates_nulls()`; confirm.
    match func {
        VariadicFunc::RegexpMatch(_)
            if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
        {
            // Arguments: [input, pattern, optional flags]. Compile the literal
            // pattern once and turn the call into a unary call on the input.
            // An invalid pattern/flag combination becomes a literal error.
            let needle = exprs[1].as_literal_str().unwrap();
            let flags = if exprs.len() == 3 {
                exprs[2].as_literal_str().unwrap()
            } else {
                ""
            };
            *e = match func::build_regex(needle, flags) {
                Ok(regex) => mem::take(exprs)
                    .into_first()
                    .call_unary(UnaryFunc::RegexpMatch(func::RegexpMatch(regex))),
                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
            };
        }
        VariadicFunc::RegexpReplace(_)
            if exprs[1].is_literal() && exprs.get(3).map_or(true, |e| e.is_literal()) =>
        {
            // Arguments: [source, pattern, replacement, optional flags]. The
            // flags string is split into the `limit` passed to `RegexpReplace`
            // and the flags used to compile the regex.
            let pattern = exprs[1].as_literal_str().unwrap();
            let flags = exprs
                .get(3)
                .map_or("", |expr| expr.as_literal_str().unwrap());
            let (limit, flags) = regexp_replace_parse_flags(flags);

            // The behavior of `regexp_replace` is that if the data is `NULL`, the
            // function returns `NULL`, independently of whether the pattern or
            // flags are correct. We need to check for this case and introduce an
            // if-then-else on the error path to only surface the error if both
            // the source and replacement inputs are non-NULL.
            *e = match func::build_regex(pattern, &flags) {
                Ok(regex) => {
                    let mut exprs = mem::take(exprs);
                    // Remove index 2 before index 0: `swap_remove` moves the last
                    // element into the vacated slot, so removing 0 first would
                    // displace the replacement argument.
                    let replacement = exprs.swap_remove(2);
                    let source = exprs.swap_remove(0);
                    source.call_binary(
                        replacement,
                        BinaryFunc::from(func::RegexpReplace { regex, limit }),
                    )
                }
                Err(err) => {
                    let mut exprs = mem::take(exprs);
                    // Same removal order as above, for the same reason.
                    let replacement = exprs.swap_remove(2);
                    let source = exprs.swap_remove(0);
                    // NOTE(review): `mem::take` above has already emptied `e`'s
                    // argument list, so this `e.typ` sees zero arguments. This
                    // relies on `RegexpReplace`'s output type not depending on
                    // its inputs; confirm.
                    let scalar_type = e.typ(column_types).scalar_type;
                    // We need to return `NULL` on `NULL` input, and error otherwise.
                    source
                        .call_is_null()
                        .or(replacement.call_is_null())
                        .if_then_else(
                            MirScalarExpr::literal_null(scalar_type.clone()),
                            MirScalarExpr::literal(Err(err), scalar_type),
                        )
                }
            };
        }
        VariadicFunc::RegexpSplitToArray(_)
            if exprs[1].is_literal() && exprs.get(2).map_or(true, |e| e.is_literal()) =>
        {
            // Arguments: [input, pattern, optional flags]. Same shape as the
            // `RegexpMatch` rewrite: compile once, then apply a unary function.
            let needle = exprs[1].as_literal_str().unwrap();
            let flags = if exprs.len() == 3 {
                exprs[2].as_literal_str().unwrap()
            } else {
                ""
            };
            *e = match func::build_regex(needle, flags) {
                Ok(regex) => {
                    mem::take(exprs)
                        .into_first()
                        .call_unary(UnaryFunc::RegexpSplitToArray(func::RegexpSplitToArray(
                            regex,
                        )))
                }
                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
            };
        }
        VariadicFunc::ListIndex(_) if is_list_create_call(&exprs[0]) => {
            // We are looking for ListIndex(ListCreate, literal), and eliminate
            // both the ListIndex and the ListCreate. E.g.: `LIST[f1,f2][2]` --> `f2`
            // `exprs[0]` is the list; `exprs[1..]` are the (possibly non-literal)
            // indexes, one per nesting level.
            let ind_exprs = exprs.split_off(1);
            let top_list_create = exprs.swap_remove(0);
            *e = reduce_list_create_list_index_literal(top_list_create, ind_exprs);
        }
        VariadicFunc::And(_) | VariadicFunc::Or(_) => {
            // Note: It's important that we have called `flatten_associative` above.
            e.undistribute_and_or();
            e.reduce_and_canonicalize_and_or();
        }
        VariadicFunc::TimezoneTimeVariadic(_)
            if exprs[0].is_literal() && exprs[2].is_literal_ok() =>
        {
            // Arguments: [timezone string, time, timestamptz]. With a literal
            // timezone and a literal timestamptz, parse the timezone (POSIX
            // spec) once and turn the call into a unary `TimezoneTime` on the
            // time argument. The `wall_time` is the timestamptz's naive UTC
            // value. An unparsable timezone becomes a literal error.
            let tz = exprs[0].as_literal_str().unwrap();
            *e = match parse_timezone(tz, TimezoneSpec::Posix) {
                Ok(tz) => MirScalarExpr::CallUnary {
                    func: UnaryFunc::TimezoneTime(func::TimezoneTime {
                        tz,
                        wall_time: exprs[2]
                            .as_literal()
                            .unwrap()
                            .unwrap()
                            .unwrap_timestamptz()
                            .naive_utc(),
                    }),
                    expr: Box::new(exprs[1].take()),
                },
                Err(err) => MirScalarExpr::literal(Err(err), e.typ(column_types).scalar_type),
            };
        }
        // No function-specific rewrite applies; `e` keeps its flattened form.
        _ => {}
    }
}
173
174/// Simplifies a `Coalesce`:
175/// 1. If all arguments are null, the result is null.
176/// 2. Drop null arguments (none of them can be the result).
177/// 3. Truncate after the first argument known to be non-null (a literal or a
178///    non-nullable column).
179/// 4. Deduplicate arguments (e.g. `coalesce(#0, #0) → coalesce(#0)`).
180/// 5. Unwrap a single-argument `coalesce`.
181fn simplify_coalesce(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
182    let MirScalarExpr::CallVariadic { exprs, .. } = e else {
183        unreachable!()
184    };
185
186    // If all inputs are null, output is null. This check must
187    // be done before `exprs.retain...` because `e.typ` requires
188    // > 0 `exprs` remain.
189    if exprs.iter().all(|x| x.is_literal_null()) {
190        *e = MirScalarExpr::literal_null(e.typ(column_types).scalar_type);
191        return;
192    }
193
194    // Remove any null values if not all values are null.
195    exprs.retain(|x| !x.is_literal_null());
196
197    // Find the first argument that is a literal or non-nullable
198    // column. All arguments after it get ignored, so throw them
199    // away. This intentionally throws away errors that can
200    // never happen.
201    if let Some(i) = exprs
202        .iter()
203        .position(|x| x.is_literal() || !x.typ(column_types).nullable)
204    {
205        exprs.truncate(i + 1);
206    }
207
208    // Deduplicate arguments in cases like `coalesce(#0, #0)`.
209    let mut seen = BTreeSet::new();
210    exprs.retain(|x| seen.insert(x.clone()));
211
212    if exprs.len() == 1 {
213        // Only one argument, so the coalesce is a no-op.
214        *e = exprs[0].take();
215    }
216}
217
218fn is_list_create_call(expr: &MirScalarExpr) -> bool {
219    matches!(
220        expr,
221        MirScalarExpr::CallVariadic {
222            func: VariadicFunc::ListCreate(..),
223            ..
224        }
225    )
226}
227
228fn list_create_type(list_create: &MirScalarExpr) -> ReprScalarType {
229    if let MirScalarExpr::CallVariadic {
230        func: VariadicFunc::ListCreate(ListCreate { elem_type: typ }),
231        ..
232    } = list_create
233    {
234        ReprScalarType::from(typ)
235    } else {
236        unreachable!()
237    }
238}
239
240/// Partial-evaluates a list indexing with a literal directly after a list creation.
241///
242/// Multi-dimensional lists are handled by a single call to this function, with multiple
243/// elements in index_exprs (of which not all need to be literals), and nested ListCreates
244/// in list_create_to_reduce.
245///
246/// # Examples
247///
248/// `LIST[f1,f2][2]` --> `f2`.
249///
250/// A multi-dimensional list, with only some of the indexes being literals:
251/// `LIST[[[f1, f2], [f3, f4]], [[f5, f6], [f7, f8]]] [2][n][2]` --> `LIST[f6, f8] [n]`
252///
253/// See more examples in list.slt.
254fn reduce_list_create_list_index_literal(
255    mut list_create_to_reduce: MirScalarExpr,
256    mut index_exprs: Vec<MirScalarExpr>,
257) -> MirScalarExpr {
258    // We iterate over the index_exprs and remove literals, but keep non-literals.
259    // When we encounter a non-literal, we need to dig into the nested ListCreates:
260    // `list_create_mut_refs` will contain all the ListCreates of the current level. If an
261    // element of `list_create_mut_refs` is not actually a ListCreate, then we break out of
262    // the loop. When we remove a literal, we need to partial-evaluate all ListCreates
263    // that are at the current level (except those that disappeared due to
264    // literals at earlier levels), index into them with the literal, and change each
265    // element in `list_create_mut_refs` to the result.
266    // We also record mut refs to all the earlier `element_type` references that we have
267    // seen in ListCreate calls, because when we process a literal index, we need to remove
268    // one layer of list type from all these earlier ListCreate `element_type`s.
269    let mut list_create_mut_refs = vec![&mut list_create_to_reduce];
270    let mut earlier_list_create_types: Vec<&mut SqlScalarType> = vec![];
271    let mut i = 0;
272    while i < index_exprs.len()
273        && list_create_mut_refs
274            .iter()
275            .all(|lc| is_list_create_call(lc))
276    {
277        if index_exprs[i].is_literal_ok() {
278            // We can remove this index.
279            let removed_index = index_exprs.remove(i);
280            let index_i64 = match removed_index.as_literal().unwrap().unwrap() {
281                Datum::Int64(sql_index_i64) => sql_index_i64 - 1,
282                _ => unreachable!(), // always an Int64, see plan_index_list
283            };
284            // For each list_create referenced by list_create_mut_refs, substitute it by its
285            // `index`th argument (or null).
286            for list_create in &mut list_create_mut_refs {
287                let list_create_args = match list_create {
288                    MirScalarExpr::CallVariadic {
289                        func: VariadicFunc::ListCreate(ListCreate { elem_type: _ }),
290                        exprs,
291                    } => exprs,
292                    _ => unreachable!(), // func cannot be anything else than a ListCreate
293                };
294                // ListIndex gives null on an out-of-bounds index
295                if index_i64 >= 0 && index_i64 < list_create_args.len().try_into().unwrap() {
296                    let index: usize = index_i64.try_into().unwrap();
297                    **list_create = list_create_args.swap_remove(index);
298                } else {
299                    let typ = list_create_type(list_create);
300                    **list_create = MirScalarExpr::literal_null(typ);
301                }
302            }
303            // Peel one layer off of each of the earlier element types.
304            for t in earlier_list_create_types.iter_mut() {
305                if let SqlScalarType::List {
306                    element_type,
307                    custom_id: _,
308                } = t
309                {
310                    **t = *element_type.clone();
311                    // These are not the same types anymore, so remove custom_ids all the
312                    // way down.
313                    let mut u = &mut **t;
314                    while let SqlScalarType::List {
315                        element_type,
316                        custom_id,
317                    } = u
318                    {
319                        *custom_id = None;
320                        u = &mut **element_type;
321                    }
322                } else {
323                    unreachable!("already matched below");
324                }
325            }
326        } else {
327            // We can't remove this index, so we can't reduce any of the ListCreates at this
328            // level. So we change list_create_mut_refs to refer to all the arguments of all
329            // the ListCreates currently referenced by list_create_mut_refs.
330            list_create_mut_refs = list_create_mut_refs
331                .into_iter()
332                .flat_map(|list_create| match list_create {
333                    MirScalarExpr::CallVariadic {
334                        func: VariadicFunc::ListCreate(ListCreate { elem_type }),
335                        exprs: list_create_args,
336                    } => {
337                        earlier_list_create_types.push(elem_type);
338                        list_create_args
339                    }
340                    // func cannot be anything else than a ListCreate
341                    _ => unreachable!(),
342                })
343                .collect();
344            i += 1;
345        }
346    }
347    // If all list indexes have been evaluated, return the reduced expression.
348    // Otherwise, rebuild the ListIndex call with the remaining ListCreates and indexes.
349    if index_exprs.is_empty() {
350        assert_eq!(list_create_mut_refs.len(), 1);
351        list_create_to_reduce
352    } else {
353        MirScalarExpr::call_variadic(
354            ListIndex,
355            std::iter::once(list_create_to_reduce)
356                .chain(index_exprs)
357                .collect(),
358        )
359    }
360}