mz_expr/scalar/func/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Variadic functions.
15
16use std::cmp;
17
18use chrono::NaiveDate;
19use fallible_iterator::FallibleIterator;
20use hmac::{Hmac, Mac};
21use itertools::Itertools;
22use md5::Md5;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::{CastFrom, ReinterpretCast};
25use mz_pgtz::timezone::TimezoneSpec;
26use mz_repr::adt::array::ArrayDimension;
27use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem};
28use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags};
29use mz_repr::adt::system::Oid;
30use mz_repr::adt::timestamp::CheckedTimestamp;
31use mz_repr::role_id::RoleId;
32use mz_repr::{
33    ColumnName, Datum, DatumType, ReprScalarType, Row, RowArena, SqlColumnType, SqlScalarType,
34};
35use serde::{Deserialize, Serialize};
36use sha1::Sha1;
37use sha2::{Sha224, Sha256, Sha384, Sha512};
38
39use crate::func::{
40    MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, parse_timezone,
41    regexp_match_static, regexp_replace_parse_flags, regexp_replace_static,
42    regexp_split_to_array_re, stringify_datum, timezone_time,
43};
44use crate::{EvalError, MirScalarExpr};
45
46pub fn and<'a>(
47    datums: &[Datum<'a>],
48    temp_storage: &'a RowArena,
49    exprs: &'a [MirScalarExpr],
50) -> Result<Datum<'a>, EvalError> {
51    // If any is false, then return false. Else, if any is null, then return null. Else, return true.
52    let mut null = false;
53    let mut err = None;
54    for expr in exprs {
55        match expr.eval(datums, temp_storage) {
56            Ok(Datum::False) => return Ok(Datum::False), // short-circuit
57            Ok(Datum::True) => {}
58            // No return in these two cases, because we might still see a false
59            Ok(Datum::Null) => null = true,
60            Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
61            _ => unreachable!(),
62        }
63    }
64    match (err, null) {
65        (Some(err), _) => Err(err),
66        (None, true) => Ok(Datum::Null),
67        (None, false) => Ok(Datum::True),
68    }
69}
70
71/// Constructs a new multidimensional array out of an arbitrary number of
72/// lower-dimensional arrays.
73///
74/// For example, if given three 1D arrays of length 2, this function will
75/// construct a 2D array with dimensions 3x2.
76///
77/// The input datums in `datums` must all be arrays of the same dimensions.
78/// (The arrays must also be of the same element type, but that is checked by
79/// the SQL type system, rather than checked here at runtime.)
80///
81/// If all input arrays are zero-dimensional arrays, then the output is a zero-
82/// dimensional array. Otherwise the lower bound of the additional dimension is
83/// one and the length of the new dimension is equal to `datums.len()`.
84fn array_create_multidim<'a>(
85    datums: &[Datum<'a>],
86    temp_storage: &'a RowArena,
87) -> Result<Datum<'a>, EvalError> {
88    // Per PostgreSQL, if all input arrays are zero dimensional, so is the
89    // output.
90    if datums.iter().all(|d| d.unwrap_array().dims().is_empty()) {
91        let dims = &[];
92        let datums = &[];
93        let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(dims, datums))?;
94        return Ok(datum);
95    }
96
97    let mut dims = vec![ArrayDimension {
98        lower_bound: 1,
99        length: datums.len(),
100    }];
101    if let Some(d) = datums.first() {
102        dims.extend(d.unwrap_array().dims());
103    };
104    let elements = datums
105        .iter()
106        .flat_map(|d| d.unwrap_array().elements().iter());
107    let datum =
108        temp_storage.try_make_datum(move |packer| packer.try_push_array(&dims, elements))?;
109    Ok(datum)
110}
111
112fn array_fill<'a>(
113    datums: &[Datum<'a>],
114    temp_storage: &'a RowArena,
115) -> Result<Datum<'a>, EvalError> {
116    const MAX_SIZE: usize = 1 << 28 - 1;
117    const NULL_ARR_ERR: &str = "dimension array or low bound array";
118    const NULL_ELEM_ERR: &str = "dimension values";
119
120    let fill = datums[0];
121    if matches!(fill, Datum::Array(_)) {
122        return Err(EvalError::Unsupported {
123            feature: "array_fill with arrays".into(),
124            discussion_no: None,
125        });
126    }
127
128    let arr = match datums[1] {
129        Datum::Null => return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into())),
130        o => o.unwrap_array(),
131    };
132
133    let dimensions = arr
134        .elements()
135        .iter()
136        .map(|d| match d {
137            Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
138            d => Ok(usize::cast_from(u32::reinterpret_cast(d.unwrap_int32()))),
139        })
140        .collect::<Result<Vec<_>, _>>()?;
141
142    let lower_bounds = match datums.get(2) {
143        Some(d) => {
144            let arr = match d {
145                Datum::Null => return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into())),
146                o => o.unwrap_array(),
147            };
148
149            arr.elements()
150                .iter()
151                .map(|l| match l {
152                    Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
153                    l => Ok(isize::cast_from(l.unwrap_int32())),
154                })
155                .collect::<Result<Vec<_>, _>>()?
156        }
157        None => {
158            vec![1isize; dimensions.len()]
159        }
160    };
161
162    if lower_bounds.len() != dimensions.len() {
163        return Err(EvalError::ArrayFillWrongArraySubscripts);
164    }
165
166    let fill_count: usize = dimensions
167        .iter()
168        .cloned()
169        .map(Some)
170        .reduce(|a, b| match (a, b) {
171            (Some(a), Some(b)) => a.checked_mul(b),
172            _ => None,
173        })
174        .flatten()
175        .ok_or(EvalError::MaxArraySizeExceeded(MAX_SIZE))?;
176
177    if matches!(
178        mz_repr::datum_size(&fill).checked_mul(fill_count),
179        None | Some(MAX_SIZE..)
180    ) {
181        return Err(EvalError::MaxArraySizeExceeded(MAX_SIZE));
182    }
183
184    let array_dimensions = if fill_count == 0 {
185        vec![ArrayDimension {
186            lower_bound: 1,
187            length: 0,
188        }]
189    } else {
190        dimensions
191            .into_iter()
192            .zip_eq(lower_bounds)
193            .map(|(length, lower_bound)| ArrayDimension {
194                lower_bound,
195                length,
196            })
197            .collect()
198    };
199
200    Ok(temp_storage.try_make_datum(|packer| {
201        packer.try_push_array(&array_dimensions, vec![fill; fill_count])
202    })?)
203}
204
205fn array_index<'a>(datums: &[Datum<'a>], offset: i64) -> Datum<'a> {
206    mz_ore::soft_assert_no_log!(offset == 0 || offset == 1, "offset must be either 0 or 1");
207
208    let array = datums[0].unwrap_array();
209    let dims = array.dims();
210    if dims.len() != datums.len() - 1 {
211        // You missed the datums "layer"
212        return Datum::Null;
213    }
214
215    let mut final_idx = 0;
216
217    for (d, idx) in dims.into_iter().zip_eq(datums[1..].iter()) {
218        // Lower bound is written in terms of 1-based indexing, which offset accounts for.
219        let idx = isize::cast_from(idx.unwrap_int64() + offset);
220
221        let (lower, upper) = d.dimension_bounds();
222
223        // This index missed all of the data at this layer. The dimension bounds are inclusive,
224        // while range checks are exclusive, so adjust.
225        if !(lower..upper + 1).contains(&idx) {
226            return Datum::Null;
227        }
228
229        // We discover how many indices our last index represents physically.
230        final_idx *= d.length;
231
232        // Because both index and lower bound are handled in 1-based indexing, taking their
233        // difference moves us back into 0-based indexing. Similarly, if the lower bound is
234        // negative, subtracting a negative value >= to itself ensures its non-negativity.
235        final_idx += usize::try_from(idx - d.lower_bound)
236            .expect("previous bounds check ensures phsical index is at least 0");
237    }
238
239    array
240        .elements()
241        .iter()
242        .nth(final_idx)
243        .unwrap_or(Datum::Null)
244}
245
246fn array_position<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
247    let array = match datums[0] {
248        Datum::Null => return Ok(Datum::Null),
249        o => o.unwrap_array(),
250    };
251
252    if array.dims().len() > 1 {
253        return Err(EvalError::MultiDimensionalArraySearch);
254    }
255
256    let search = datums[1];
257    if search == Datum::Null {
258        return Ok(Datum::Null);
259    }
260
261    let skip: usize = match datums.get(2) {
262        Some(Datum::Null) => return Err(EvalError::MustNotBeNull("initial position".into())),
263        None => 0,
264        Some(o) => usize::try_from(o.unwrap_int32())
265            .unwrap_or(0)
266            .saturating_sub(1),
267    };
268
269    let r = array.elements().iter().skip(skip).position(|d| d == search);
270
271    Ok(Datum::from(r.map(|p| {
272        // Adjust count for the amount we skipped, plus 1 for adjustng to PG indexing scheme.
273        i32::try_from(p + skip + 1).expect("fewer than i32::MAX elements in array")
274    })))
275}
276
277// WARNING: This function has potential OOM risk!
278// It is very difficult to calculate the output size ahead of time without knowing how to
279// calculate the stringified size of each element for all possible datatypes.
280fn array_to_string<'a>(
281    datums: &[Datum<'a>],
282    elem_type: &SqlScalarType,
283    temp_storage: &'a RowArena,
284) -> Result<Datum<'a>, EvalError> {
285    if datums[0].is_null() || datums[1].is_null() {
286        return Ok(Datum::Null);
287    }
288    let array = datums[0].unwrap_array();
289    let delimiter = datums[1].unwrap_str();
290    let null_str = match datums.get(2) {
291        None | Some(Datum::Null) => None,
292        Some(d) => Some(d.unwrap_str()),
293    };
294
295    let mut out = String::new();
296    for elem in array.elements().iter() {
297        if elem.is_null() {
298            if let Some(null_str) = null_str {
299                out.push_str(null_str);
300                out.push_str(delimiter);
301            }
302        } else {
303            stringify_datum(&mut out, elem, elem_type)?;
304            out.push_str(delimiter);
305        }
306    }
307    if out.len() > 0 {
308        // Lop off last delimiter only if string is not empty
309        out.truncate(out.len() - delimiter.len());
310    }
311    Ok(Datum::String(temp_storage.push_string(out)))
312}
313
314fn coalesce<'a>(
315    datums: &[Datum<'a>],
316    temp_storage: &'a RowArena,
317    exprs: &'a [MirScalarExpr],
318) -> Result<Datum<'a>, EvalError> {
319    for e in exprs {
320        let d = e.eval(datums, temp_storage)?;
321        if !d.is_null() {
322            return Ok(d);
323        }
324    }
325    Ok(Datum::Null)
326}
327
328fn create_range<'a>(
329    datums: &[Datum<'a>],
330    temp_storage: &'a RowArena,
331) -> Result<Datum<'a>, EvalError> {
332    let flags = match datums[2] {
333        Datum::Null => {
334            return Err(EvalError::InvalidRange(
335                InvalidRangeError::NullRangeBoundFlags,
336            ));
337        }
338        o => o.unwrap_str(),
339    };
340
341    let (lower_inclusive, upper_inclusive) = parse_range_bound_flags(flags)?;
342
343    let mut range = Range::new(Some((
344        RangeBound::new(datums[0], lower_inclusive),
345        RangeBound::new(datums[1], upper_inclusive),
346    )));
347
348    range.canonicalize()?;
349
350    Ok(temp_storage.make_datum(|row| {
351        row.push_range(range).expect("errors already handled");
352    }))
353}
354
355fn date_diff_date<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
356    let unit = unit.unwrap_str();
357    let unit = unit
358        .parse()
359        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
360
361    let a = a.unwrap_date();
362    let b = b.unwrap_date();
363
364    // Convert the Date into a timestamp so we can calculate age.
365    let a_ts = CheckedTimestamp::try_from(NaiveDate::from(a).and_hms_opt(0, 0, 0).unwrap())?;
366    let b_ts = CheckedTimestamp::try_from(NaiveDate::from(b).and_hms_opt(0, 0, 0).unwrap())?;
367    let diff = b_ts.diff_as(&a_ts, unit)?;
368
369    Ok(Datum::Int64(diff))
370}
371
372fn date_diff_time<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
373    let unit = unit.unwrap_str();
374    let unit = unit
375        .parse()
376        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
377
378    let a = a.unwrap_time();
379    let b = b.unwrap_time();
380
381    // Convert the Time into a timestamp so we can calculate age.
382    let a_ts =
383        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(a))?;
384    let b_ts =
385        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(b))?;
386    let diff = b_ts.diff_as(&a_ts, unit)?;
387
388    Ok(Datum::Int64(diff))
389}
390
391fn date_diff_timestamp<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
392    let unit = unit.unwrap_str();
393    let unit = unit
394        .parse()
395        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
396
397    let a = a.unwrap_timestamp();
398    let b = b.unwrap_timestamp();
399    let diff = b.diff_as(&a, unit)?;
400
401    Ok(Datum::Int64(diff))
402}
403
404fn date_diff_timestamptz<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
405    let unit = unit.unwrap_str();
406    let unit = unit
407        .parse()
408        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
409
410    let a = a.unwrap_timestamptz();
411    let b = b.unwrap_timestamptz();
412    let diff = b.diff_as(&a, unit)?;
413
414    Ok(Datum::Int64(diff))
415}
416
417fn error_if_null<'a>(
418    datums: &[Datum<'a>],
419    temp_storage: &'a RowArena,
420    exprs: &'a [MirScalarExpr],
421) -> Result<Datum<'a>, EvalError> {
422    let first = exprs[0].eval(datums, temp_storage)?;
423    match first {
424        Datum::Null => {
425            let err_msg = match exprs[1].eval(datums, temp_storage)? {
426                Datum::Null => {
427                    return Err(EvalError::Internal(
428                        "unexpected NULL in error side of error_if_null".into(),
429                    ));
430                }
431                o => o.unwrap_str(),
432            };
433            Err(EvalError::IfNullError(err_msg.into()))
434        }
435        _ => Ok(first),
436    }
437}
438
439fn greatest<'a>(
440    datums: &[Datum<'a>],
441    temp_storage: &'a RowArena,
442    exprs: &'a [MirScalarExpr],
443) -> Result<Datum<'a>, EvalError> {
444    let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
445    Ok(datums
446        .filter(|d| Ok(!d.is_null()))
447        .max()?
448        .unwrap_or(Datum::Null))
449}
450
451pub fn hmac_string<'a>(
452    datums: &[Datum<'a>],
453    temp_storage: &'a RowArena,
454) -> Result<Datum<'a>, EvalError> {
455    let to_digest = datums[0].unwrap_str().as_bytes();
456    let key = datums[1].unwrap_str().as_bytes();
457    let typ = datums[2].unwrap_str();
458    hmac_inner(to_digest, key, typ, temp_storage)
459}
460
461pub fn hmac_bytes<'a>(
462    datums: &[Datum<'a>],
463    temp_storage: &'a RowArena,
464) -> Result<Datum<'a>, EvalError> {
465    let to_digest = datums[0].unwrap_bytes();
466    let key = datums[1].unwrap_bytes();
467    let typ = datums[2].unwrap_str();
468    hmac_inner(to_digest, key, typ, temp_storage)
469}
470
471pub fn hmac_inner<'a>(
472    to_digest: &[u8],
473    key: &[u8],
474    typ: &str,
475    temp_storage: &'a RowArena,
476) -> Result<Datum<'a>, EvalError> {
477    let bytes = match typ {
478        "md5" => {
479            let mut mac = Hmac::<Md5>::new_from_slice(key).expect("HMAC accepts any key size");
480            mac.update(to_digest);
481            mac.finalize().into_bytes().to_vec()
482        }
483        "sha1" => {
484            let mut mac = Hmac::<Sha1>::new_from_slice(key).expect("HMAC accepts any key size");
485            mac.update(to_digest);
486            mac.finalize().into_bytes().to_vec()
487        }
488        "sha224" => {
489            let mut mac = Hmac::<Sha224>::new_from_slice(key).expect("HMAC accepts any key size");
490            mac.update(to_digest);
491            mac.finalize().into_bytes().to_vec()
492        }
493        "sha256" => {
494            let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key size");
495            mac.update(to_digest);
496            mac.finalize().into_bytes().to_vec()
497        }
498        "sha384" => {
499            let mut mac = Hmac::<Sha384>::new_from_slice(key).expect("HMAC accepts any key size");
500            mac.update(to_digest);
501            mac.finalize().into_bytes().to_vec()
502        }
503        "sha512" => {
504            let mut mac = Hmac::<Sha512>::new_from_slice(key).expect("HMAC accepts any key size");
505            mac.update(to_digest);
506            mac.finalize().into_bytes().to_vec()
507        }
508        other => return Err(EvalError::InvalidHashAlgorithm(other.into())),
509    };
510    Ok(Datum::Bytes(temp_storage.push_bytes(bytes)))
511}
512
513fn jsonb_build_array<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
514    temp_storage.make_datum(|packer| {
515        packer.push_list(datums.into_iter().map(|d| match d {
516            Datum::Null => Datum::JsonNull,
517            d => *d,
518        }))
519    })
520}
521
522fn jsonb_build_object<'a>(
523    datums: &[Datum<'a>],
524    temp_storage: &'a RowArena,
525) -> Result<Datum<'a>, EvalError> {
526    let mut kvs = datums.chunks(2).collect::<Vec<_>>();
527    kvs.sort_by(|kv1, kv2| kv1[0].cmp(&kv2[0]));
528    kvs.dedup_by(|kv1, kv2| kv1[0] == kv2[0]);
529    temp_storage.try_make_datum(|packer| {
530        packer.push_dict_with(|packer| {
531            for kv in kvs {
532                let k = kv[0];
533                if k.is_null() {
534                    return Err(EvalError::KeyCannotBeNull);
535                };
536                let v = match kv[1] {
537                    Datum::Null => Datum::JsonNull,
538                    d => d,
539                };
540                packer.push(k);
541                packer.push(v);
542            }
543            Ok(())
544        })
545    })
546}
547
548fn least<'a>(
549    datums: &[Datum<'a>],
550    temp_storage: &'a RowArena,
551    exprs: &'a [MirScalarExpr],
552) -> Result<Datum<'a>, EvalError> {
553    let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
554    Ok(datums
555        .filter(|d| Ok(!d.is_null()))
556        .min()?
557        .unwrap_or(Datum::Null))
558}
559
560fn list_create<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
561    temp_storage.make_datum(|packer| packer.push_list(datums))
562}
563
564// TODO(benesch): remove potentially dangerous usage of `as`.
565#[allow(clippy::as_conversions)]
566fn list_index<'a>(datums: &[Datum<'a>]) -> Datum<'a> {
567    let mut buf = datums[0];
568
569    for i in datums[1..].iter() {
570        if buf.is_null() {
571            break;
572        }
573
574        let i = i.unwrap_int64();
575        if i < 1 {
576            return Datum::Null;
577        }
578
579        buf = buf
580            .unwrap_list()
581            .iter()
582            .nth(i as usize - 1)
583            .unwrap_or(Datum::Null);
584    }
585    buf
586}
587
588fn make_acl_item<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
589    let grantee = Oid(datums[0].unwrap_uint32());
590    let grantor = Oid(datums[1].unwrap_uint32());
591    let privileges = datums[2].unwrap_str();
592    let acl_mode = AclMode::parse_multiple_privileges(privileges)
593        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
594    let is_grantable = datums[3].unwrap_bool();
595    if is_grantable {
596        return Err(EvalError::Unsupported {
597            feature: "GRANT OPTION".into(),
598            discussion_no: None,
599        });
600    }
601
602    Ok(Datum::AclItem(AclItem {
603        grantee,
604        grantor,
605        acl_mode,
606    }))
607}
608
609fn make_mz_acl_item<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
610    let grantee: RoleId = datums[0]
611        .unwrap_str()
612        .parse()
613        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
614    let grantor: RoleId = datums[1]
615        .unwrap_str()
616        .parse()
617        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
618    if grantor == RoleId::Public {
619        return Err(EvalError::InvalidRoleId(
620            "mz_aclitem grantor cannot be PUBLIC role".into(),
621        ));
622    }
623    let privileges = datums[2].unwrap_str();
624    let acl_mode = AclMode::parse_multiple_privileges(privileges)
625        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
626
627    Ok(Datum::MzAclItem(MzAclItem {
628        grantee,
629        grantor,
630        acl_mode,
631    }))
632}
633
634// TODO(benesch): remove potentially dangerous usage of `as`.
635#[allow(clippy::as_conversions)]
636fn make_timestamp<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
637    let year: i32 = match datums[0].unwrap_int64().try_into() {
638        Ok(year) => year,
639        Err(_) => return Ok(Datum::Null),
640    };
641    let month: u32 = match datums[1].unwrap_int64().try_into() {
642        Ok(month) => month,
643        Err(_) => return Ok(Datum::Null),
644    };
645    let day: u32 = match datums[2].unwrap_int64().try_into() {
646        Ok(day) => day,
647        Err(_) => return Ok(Datum::Null),
648    };
649    let hour: u32 = match datums[3].unwrap_int64().try_into() {
650        Ok(day) => day,
651        Err(_) => return Ok(Datum::Null),
652    };
653    let minute: u32 = match datums[4].unwrap_int64().try_into() {
654        Ok(day) => day,
655        Err(_) => return Ok(Datum::Null),
656    };
657    let second_float = datums[5].unwrap_float64();
658    let second = second_float as u32;
659    let micros = ((second_float - second as f64) * 1_000_000.0) as u32;
660    let date = match NaiveDate::from_ymd_opt(year, month, day) {
661        Some(date) => date,
662        None => return Ok(Datum::Null),
663    };
664    let timestamp = match date.and_hms_micro_opt(hour, minute, second, micros) {
665        Some(timestamp) => timestamp,
666        None => return Ok(Datum::Null),
667    };
668    Ok(timestamp.try_into()?)
669}
670
671fn map_build<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
672    // Collect into a `BTreeMap` to provide the same semantics as it.
673    let map: std::collections::BTreeMap<&str, _> = datums
674        .into_iter()
675        .tuples()
676        .filter_map(|(k, v)| {
677            if k.is_null() {
678                None
679            } else {
680                Some((k.unwrap_str(), v))
681            }
682        })
683        .collect();
684
685    temp_storage.make_datum(|packer| packer.push_dict(map))
686}
687
688pub fn or<'a>(
689    datums: &[Datum<'a>],
690    temp_storage: &'a RowArena,
691    exprs: &'a [MirScalarExpr],
692) -> Result<Datum<'a>, EvalError> {
693    // If any is true, then return true. Else, if any is null, then return null. Else, return false.
694    let mut null = false;
695    let mut err = None;
696    for expr in exprs {
697        match expr.eval(datums, temp_storage) {
698            Ok(Datum::False) => {}
699            Ok(Datum::True) => return Ok(Datum::True), // short-circuit
700            // No return in these two cases, because we might still see a true
701            Ok(Datum::Null) => null = true,
702            Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
703            _ => unreachable!(),
704        }
705    }
706    match (err, null) {
707        (Some(err), _) => Err(err),
708        (None, true) => Ok(Datum::Null),
709        (None, false) => Ok(Datum::False),
710    }
711}
712
713fn pad_leading<'a>(
714    datums: &[Datum<'a>],
715    temp_storage: &'a RowArena,
716) -> Result<Datum<'a>, EvalError> {
717    let string = datums[0].unwrap_str();
718
719    let len = match usize::try_from(datums[1].unwrap_int32()) {
720        Ok(len) => len,
721        Err(_) => {
722            return Err(EvalError::InvalidParameterValue(
723                "length must be nonnegative".into(),
724            ));
725        }
726    };
727    if len > MAX_STRING_FUNC_RESULT_BYTES {
728        return Err(EvalError::LengthTooLarge);
729    }
730
731    let pad_string = if datums.len() == 3 {
732        datums[2].unwrap_str()
733    } else {
734        " "
735    };
736
737    let (end_char, end_char_byte_offset) = string
738        .chars()
739        .take(len)
740        .fold((0, 0), |acc, char| (acc.0 + 1, acc.1 + char.len_utf8()));
741
742    let mut buf = String::with_capacity(len);
743    if len == end_char {
744        buf.push_str(&string[0..end_char_byte_offset]);
745    } else {
746        buf.extend(pad_string.chars().cycle().take(len - end_char));
747        buf.push_str(string);
748    }
749
750    Ok(Datum::String(temp_storage.push_string(buf)))
751}
752
753fn regexp_match_dynamic<'a>(
754    datums: &[Datum<'a>],
755    temp_storage: &'a RowArena,
756) -> Result<Datum<'a>, EvalError> {
757    let haystack = datums[0];
758    let needle = datums[1].unwrap_str();
759    let flags = match datums.get(2) {
760        Some(d) => d.unwrap_str(),
761        None => "",
762    };
763    let needle = build_regex(needle, flags)?;
764    regexp_match_static(haystack, temp_storage, &needle)
765}
766
767fn regexp_split_to_array<'a>(
768    text: Datum<'a>,
769    regexp: Datum<'a>,
770    flags: Datum<'a>,
771    temp_storage: &'a RowArena,
772) -> Result<Datum<'a>, EvalError> {
773    let text = text.unwrap_str();
774    let regexp = regexp.unwrap_str();
775    let flags = flags.unwrap_str();
776    let regexp = build_regex(regexp, flags)?;
777    regexp_split_to_array_re(text, &regexp, temp_storage)
778}
779
780fn regexp_replace_dynamic<'a>(
781    datums: &[Datum<'a>],
782    temp_storage: &'a RowArena,
783) -> Result<Datum<'a>, EvalError> {
784    let source = datums[0];
785    let pattern = datums[1];
786    let replacement = datums[2];
787    let flags = match datums.get(3) {
788        Some(d) => d.unwrap_str(),
789        None => "",
790    };
791    let (limit, flags) = regexp_replace_parse_flags(flags);
792    let regexp = build_regex(pattern.unwrap_str(), &flags)?;
793    regexp_replace_static(source, replacement, &regexp, limit, temp_storage)
794}
795
796fn replace<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
797    // As a compromise to avoid always nearly duplicating the work of replace by doing size estimation,
798    // we first check if its possible for the fully replaced string to exceed the limit by assuming that
799    // every possible substring is replaced.
800    //
801    // If that estimate exceeds the limit, we then do a more precise (and expensive) estimate by counting
802    // the actual number of replacements that would occur, and using that to calculate the final size.
803    let text = datums[0].unwrap_str();
804    let from = datums[1].unwrap_str();
805    let to = datums[2].unwrap_str();
806    let possible_size = text.len() * to.len();
807    if possible_size > MAX_STRING_FUNC_RESULT_BYTES {
808        let replacement_count = text.matches(from).count();
809        let estimated_size = text.len() + replacement_count * (to.len().saturating_sub(from.len()));
810        if estimated_size > MAX_STRING_FUNC_RESULT_BYTES {
811            return Err(EvalError::LengthTooLarge);
812        }
813    }
814
815    Ok(Datum::String(
816        temp_storage.push_string(text.replace(from, to)),
817    ))
818}
819
820fn string_to_array<'a>(
821    string_datum: Datum<'a>,
822    delimiter: Datum<'a>,
823    null_string: Datum<'a>,
824    temp_storage: &'a RowArena,
825) -> Result<Datum<'a>, EvalError> {
826    if string_datum.is_null() {
827        return Ok(Datum::Null);
828    }
829
830    let string = string_datum.unwrap_str();
831
832    if string.is_empty() {
833        let mut row = Row::default();
834        let mut packer = row.packer();
835        packer.try_push_array(&[], std::iter::empty::<Datum>())?;
836
837        return Ok(temp_storage.push_unary_row(row));
838    }
839
840    if delimiter.is_null() {
841        let split_all_chars_delimiter = "";
842        return string_to_array_impl(string, split_all_chars_delimiter, null_string, temp_storage);
843    }
844
845    let delimiter = delimiter.unwrap_str();
846
847    if delimiter.is_empty() {
848        let mut row = Row::default();
849        let mut packer = row.packer();
850        packer.try_push_array(
851            &[ArrayDimension {
852                lower_bound: 1,
853                length: 1,
854            }],
855            vec![string].into_iter().map(Datum::String),
856        )?;
857
858        Ok(temp_storage.push_unary_row(row))
859    } else {
860        string_to_array_impl(string, delimiter, null_string, temp_storage)
861    }
862}
863
864fn string_to_array_impl<'a>(
865    string: &str,
866    delimiter: &str,
867    null_string: Datum<'a>,
868    temp_storage: &'a RowArena,
869) -> Result<Datum<'a>, EvalError> {
870    let mut row = Row::default();
871    let mut packer = row.packer();
872
873    let result = string.split(delimiter);
874    let found: Vec<&str> = if delimiter.is_empty() {
875        result.filter(|s| !s.is_empty()).collect()
876    } else {
877        result.collect()
878    };
879    let array_dimensions = [ArrayDimension {
880        lower_bound: 1,
881        length: found.len(),
882    }];
883
884    if null_string.is_null() {
885        packer.try_push_array(&array_dimensions, found.into_iter().map(Datum::String))?;
886    } else {
887        let null_string = null_string.unwrap_str();
888        let found_datums = found.into_iter().map(|chunk| {
889            if chunk.eq(null_string) {
890                Datum::Null
891            } else {
892                Datum::String(chunk)
893            }
894        });
895
896        packer.try_push_array(&array_dimensions, found_datums)?;
897    }
898
899    Ok(temp_storage.push_unary_row(row))
900}
901
/// Returns a substring of the string `datums[0]`, using character (not byte)
/// positions.
///
/// `datums[1]` is the 1-based start position and may be zero or negative.
/// With the optional length `datums[2]`, positions before the first character
/// still count towards the length but select nothing. Without a length, the
/// substring runs to the end of the string. Positions past the end of the
/// string are clamped to the end.
///
/// # Errors
///
/// Errors if the length is negative. Also errors if a computed character
/// position cannot be represented as a `usize`.
fn substr<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
    let s: &'a str = datums[0].unwrap_str();

    // 0-based start position. Starts before the first character are clamped to
    // 0 here; `raw_start_idx` keeps the unclamped value, which later shortens
    // the selected length.
    let raw_start_idx = i64::from(datums[1].unwrap_int32()) - 1;
    let start_idx = match usize::try_from(cmp::max(raw_start_idx, 0)) {
        Ok(i) => i,
        Err(_) => {
            return Err(EvalError::InvalidParameterValue(
                format!(
                    "substring starting index ({}) exceeds min/max position",
                    raw_start_idx
                )
                .into(),
            ));
        }
    };

    let mut char_indices = s.char_indices();
    // Maps a `(byte_offset, char)` pair to its byte offset.
    let get_str_index = |(index, _char)| index;

    let str_len = s.len();
    // Byte offset of the starting character, or the end of the string if the
    // start is past the end. Note: this advances `char_indices` past the
    // starting character, which the end computation below relies on.
    let start_char_idx = char_indices.nth(start_idx).map_or(str_len, get_str_index);

    if datums.len() == 3 {
        // `end_idx` counts characters *after* the starting character (already
        // consumed above) up to the first character past the substring.
        let end_idx = match i64::from(datums[2].unwrap_int32()) {
            e if e < 0 => {
                return Err(EvalError::InvalidParameterValue(
                    "negative substring length not allowed".into(),
                ));
            }
            // Zero length, or a range that ends before the first character.
            e if e == 0 || e + raw_start_idx < 1 => return Ok(Datum::String("")),
            e => {
                // With a negative start, the characters before position 1 use
                // up part of the requested length.
                let e = cmp::min(raw_start_idx + e - 1, e - 1);
                match usize::try_from(e) {
                    Ok(i) => i,
                    Err(_) => {
                        return Err(EvalError::InvalidParameterValue(
                            format!("substring length ({}) exceeds max position", e).into(),
                        ));
                    }
                }
            }
        };

        // Byte offset just past the substring, clamped to the end of the string.
        let end_char_idx = char_indices.nth(end_idx).map_or(str_len, get_str_index);

        Ok(Datum::String(&s[start_char_idx..end_char_idx]))
    } else {
        Ok(Datum::String(&s[start_char_idx..]))
    }
}
953
954fn split_part<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
955    let string = datums[0].unwrap_str();
956    let delimiter = datums[1].unwrap_str();
957
958    // Provided index value begins at 1, not 0.
959    let index = match usize::try_from(i64::from(datums[2].unwrap_int32()) - 1) {
960        Ok(index) => index,
961        Err(_) => {
962            return Err(EvalError::InvalidParameterValue(
963                "field position must be greater than zero".into(),
964            ));
965        }
966    };
967
968    // If the provided delimiter is the empty string,
969    // PostgreSQL does not break the string into individual
970    // characters. Instead, it generates the following parts: [string].
971    if delimiter.is_empty() {
972        if index == 0 {
973            return Ok(datums[0]);
974        } else {
975            return Ok(Datum::String(""));
976        }
977    }
978
979    // If provided index is greater than the number of split parts,
980    // return an empty string.
981    Ok(Datum::String(
982        string.split(delimiter).nth(index).unwrap_or(""),
983    ))
984}
985
986fn text_concat_variadic<'a>(
987    datums: &[Datum<'a>],
988    temp_storage: &'a RowArena,
989) -> Result<Datum<'a>, EvalError> {
990    let mut total_size = 0;
991    for d in datums {
992        if !d.is_null() {
993            total_size += d.unwrap_str().len();
994            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
995                return Err(EvalError::LengthTooLarge);
996            }
997        }
998    }
999    let mut buf = String::new();
1000    for d in datums {
1001        if !d.is_null() {
1002            buf.push_str(d.unwrap_str());
1003        }
1004    }
1005    Ok(Datum::String(temp_storage.push_string(buf)))
1006}
1007
1008fn text_concat_ws<'a>(
1009    datums: &[Datum<'a>],
1010    temp_storage: &'a RowArena,
1011) -> Result<Datum<'a>, EvalError> {
1012    let ws = match datums[0] {
1013        Datum::Null => return Ok(Datum::Null),
1014        d => d.unwrap_str(),
1015    };
1016
1017    let mut total_size = 0;
1018    for d in &datums[1..] {
1019        if !d.is_null() {
1020            total_size += d.unwrap_str().len();
1021            total_size += ws.len();
1022            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1023                return Err(EvalError::LengthTooLarge);
1024            }
1025        }
1026    }
1027
1028    let buf = Itertools::join(
1029        &mut datums[1..].iter().filter_map(|d| match d {
1030            Datum::Null => None,
1031            d => Some(d.unwrap_str()),
1032        }),
1033        ws,
1034    );
1035
1036    Ok(Datum::String(temp_storage.push_string(buf)))
1037}
1038
1039fn translate<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
1040    let string = datums[0].unwrap_str();
1041    let from = datums[1].unwrap_str().chars().collect::<Vec<_>>();
1042    let to = datums[2].unwrap_str().chars().collect::<Vec<_>>();
1043
1044    Datum::String(
1045        temp_storage.push_string(
1046            string
1047                .chars()
1048                .filter_map(|c| match from.iter().position(|f| f == &c) {
1049                    Some(idx) => to.get(idx).copied(),
1050                    None => Some(c),
1051                })
1052                .collect(),
1053        ),
1054    )
1055}
1056
1057// TODO ///
1058
1059// TODO(benesch): remove potentially dangerous usage of `as`.
1060#[allow(clippy::as_conversions)]
1061fn list_slice_linear<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
1062    assert_eq!(
1063        datums.len() % 2,
1064        1,
1065        "expr::scalar::func::list_slice expects an odd number of arguments; 1 for list + 2 \
1066        for each start-end pair"
1067    );
1068    assert!(
1069        datums.len() > 2,
1070        "expr::scalar::func::list_slice expects at least 3 arguments; 1 for list + at least \
1071        one start-end pair"
1072    );
1073
1074    let mut start_idx = 0;
1075    let mut total_length = usize::MAX;
1076
1077    for (start, end) in datums[1..].iter().tuples::<(_, _)>() {
1078        let start = std::cmp::max(start.unwrap_int64(), 1);
1079        let end = end.unwrap_int64();
1080
1081        // Result should be empty list.
1082        if start > end {
1083            start_idx = 0;
1084            total_length = 0;
1085            break;
1086        }
1087
1088        let start_inner = start as usize - 1;
1089        // Start index only moves to geq positions.
1090        start_idx += start_inner;
1091
1092        // Length index only moves to leq positions
1093        let length_inner = (end - start) as usize + 1;
1094        total_length = std::cmp::min(length_inner, total_length - start_inner);
1095    }
1096
1097    let iter = datums[0]
1098        .unwrap_list()
1099        .iter()
1100        .skip(start_idx)
1101        .take(total_length);
1102
1103    temp_storage.make_datum(|row| {
1104        row.push_list_with(|row| {
1105            // if iter is empty, will get the appropriate empty list.
1106            for d in iter {
1107                row.push(d);
1108            }
1109        });
1110    })
1111}
1112
/// A scalar function that takes a variable number of arguments.
///
/// Evaluation is implemented by [`VariadicFunc::eval`] and typing by
/// [`VariadicFunc::output_type`].
///
/// NOTE: `PartialOrd`/`Ord` are derived, so variants compare in declaration
/// order, and the derived serde representation uses the variant names.
/// Reordering or renaming variants changes both.
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
pub enum VariadicFunc {
    /// Receives its argument expressions unevaluated (see [`VariadicFunc::eval`]).
    Coalesce,
    /// Receives its argument expressions unevaluated (see [`VariadicFunc::eval`]).
    Greatest,
    /// Receives its argument expressions unevaluated (see [`VariadicFunc::eval`]).
    Least,
    /// Concatenates the non-NULL arguments.
    Concat,
    /// Joins the non-NULL arguments after the first, separated by the first
    /// argument; NULL when the separator is NULL.
    ConcatWs,
    MakeTimestamp,
    PadLeading,
    /// Substring by 1-based character position and optional length.
    Substr,
    Replace,
    JsonbBuildArray,
    JsonbBuildObject,
    /// Builds a map whose values have type `value_type`.
    MapBuild {
        value_type: SqlScalarType,
    },
    /// Builds an array; when `elem_type` is itself an array type, the result
    /// is a multidimensional array.
    ArrayCreate {
        // We need to know the element type to type empty arrays.
        elem_type: SqlScalarType,
    },
    ArrayToString {
        elem_type: SqlScalarType,
    },
    ArrayIndex {
        // Adjusts the index by offset depending on whether being called on an array or an
        // Int2Vector.
        offset: i64,
    },
    ListCreate {
        // We need to know the element type to type empty lists.
        elem_type: SqlScalarType,
    },
    /// Builds a record whose fields are named by `field_names`, in argument order.
    RecordCreate {
        field_names: Vec<ColumnName>,
    },
    ListIndex,
    /// Slices a list by one or more 1-based, inclusive start/end pairs applied in sequence.
    ListSliceLinear,
    /// Returns the n-th (1-based) field of a string split on a delimiter.
    SplitPart,
    RegexpMatch,
    /// HMAC over a string message; displays as `hmac`.
    HmacString,
    /// HMAC over a bytes message; displays as `hmac`.
    HmacBytes,
    /// Receives its argument expressions unevaluated (see [`VariadicFunc::eval`]).
    ErrorIfNull,
    DateBinTimestamp,
    DateBinTimestampTz,
    DateDiffTimestamp,
    DateDiffTimestampTz,
    DateDiffDate,
    DateDiffTime,
    /// Boolean conjunction; receives its argument expressions unevaluated.
    And,
    /// Boolean disjunction; receives its argument expressions unevaluated.
    Or,
    /// Builds a range whose bounds have type `elem_type`.
    RangeCreate {
        elem_type: SqlScalarType,
    },
    MakeAclItem,
    MakeMzAclItem,
    /// Character-by-character substitution (see `translate`).
    Translate,
    ArrayPosition,
    ArrayFill {
        elem_type: SqlScalarType,
    },
    /// Splits a string into a text array; an optional third argument names
    /// the piece value to store as NULL.
    StringToArray,
    TimezoneTime,
    /// Splits a string by a regex; an optional third argument supplies flags
    /// (empty when omitted).
    RegexpSplitToArray,
    RegexpReplace,
}
1178
1179impl VariadicFunc {
1180    pub fn eval<'a>(
1181        &'a self,
1182        datums: &[Datum<'a>],
1183        temp_storage: &'a RowArena,
1184        exprs: &'a [MirScalarExpr],
1185    ) -> Result<Datum<'a>, EvalError> {
1186        // Evaluate all non-eager functions directly
1187        match self {
1188            VariadicFunc::Coalesce => return coalesce(datums, temp_storage, exprs),
1189            VariadicFunc::Greatest => return greatest(datums, temp_storage, exprs),
1190            VariadicFunc::And => return and(datums, temp_storage, exprs),
1191            VariadicFunc::Or => return or(datums, temp_storage, exprs),
1192            VariadicFunc::ErrorIfNull => return error_if_null(datums, temp_storage, exprs),
1193            VariadicFunc::Least => return least(datums, temp_storage, exprs),
1194            _ => {}
1195        };
1196
1197        // Compute parameters to eager functions
1198        let ds = exprs
1199            .iter()
1200            .map(|e| e.eval(datums, temp_storage))
1201            .collect::<Result<Vec<_>, _>>()?;
1202        // Check NULL propagation
1203        if self.propagates_nulls() && ds.iter().any(|d| d.is_null()) {
1204            return Ok(Datum::Null);
1205        }
1206
1207        // Evaluate eager functions
1208        match self {
1209            VariadicFunc::Coalesce
1210            | VariadicFunc::Greatest
1211            | VariadicFunc::And
1212            | VariadicFunc::Or
1213            | VariadicFunc::ErrorIfNull
1214            | VariadicFunc::Least => unreachable!(),
1215            VariadicFunc::Concat => text_concat_variadic(&ds, temp_storage),
1216            VariadicFunc::ConcatWs => text_concat_ws(&ds, temp_storage),
1217            VariadicFunc::MakeTimestamp => make_timestamp(&ds),
1218            VariadicFunc::PadLeading => pad_leading(&ds, temp_storage),
1219            VariadicFunc::Substr => substr(&ds),
1220            VariadicFunc::Replace => replace(&ds, temp_storage),
1221            VariadicFunc::Translate => Ok(translate(&ds, temp_storage)),
1222            VariadicFunc::JsonbBuildArray => Ok(jsonb_build_array(&ds, temp_storage)),
1223            VariadicFunc::JsonbBuildObject => jsonb_build_object(&ds, temp_storage),
1224            VariadicFunc::MapBuild { .. } => Ok(map_build(&ds, temp_storage)),
1225            VariadicFunc::ArrayCreate {
1226                elem_type: SqlScalarType::Array(_),
1227            } => array_create_multidim(&ds, temp_storage),
1228            VariadicFunc::ArrayCreate { .. } => array_create_scalar(&ds, temp_storage),
1229            VariadicFunc::ArrayToString { elem_type } => {
1230                array_to_string(&ds, elem_type, temp_storage)
1231            }
1232            VariadicFunc::ArrayIndex { offset } => Ok(array_index(&ds, *offset)),
1233
1234            VariadicFunc::ListCreate { .. } | VariadicFunc::RecordCreate { .. } => {
1235                Ok(list_create(&ds, temp_storage))
1236            }
1237            VariadicFunc::ListIndex => Ok(list_index(&ds)),
1238            VariadicFunc::ListSliceLinear => Ok(list_slice_linear(&ds, temp_storage)),
1239            VariadicFunc::SplitPart => split_part(&ds),
1240            VariadicFunc::RegexpMatch => regexp_match_dynamic(&ds, temp_storage),
1241            VariadicFunc::HmacString => hmac_string(&ds, temp_storage),
1242            VariadicFunc::HmacBytes => hmac_bytes(&ds, temp_storage),
1243            VariadicFunc::DateBinTimestamp => date_bin(
1244                ds[0].unwrap_interval(),
1245                ds[1].unwrap_timestamp(),
1246                ds[2].unwrap_timestamp(),
1247            )
1248            .into_result(temp_storage),
1249            VariadicFunc::DateBinTimestampTz => date_bin(
1250                ds[0].unwrap_interval(),
1251                ds[1].unwrap_timestamptz(),
1252                ds[2].unwrap_timestamptz(),
1253            )
1254            .into_result(temp_storage),
1255            VariadicFunc::DateDiffTimestamp => date_diff_timestamp(ds[0], ds[1], ds[2]),
1256            VariadicFunc::DateDiffTimestampTz => date_diff_timestamptz(ds[0], ds[1], ds[2]),
1257            VariadicFunc::DateDiffDate => date_diff_date(ds[0], ds[1], ds[2]),
1258            VariadicFunc::DateDiffTime => date_diff_time(ds[0], ds[1], ds[2]),
1259            VariadicFunc::RangeCreate { .. } => create_range(&ds, temp_storage),
1260            VariadicFunc::MakeAclItem => make_acl_item(&ds),
1261            VariadicFunc::MakeMzAclItem => make_mz_acl_item(&ds),
1262            VariadicFunc::ArrayPosition => array_position(&ds),
1263            VariadicFunc::ArrayFill { .. } => array_fill(&ds, temp_storage),
1264            VariadicFunc::TimezoneTime => parse_timezone(ds[0].unwrap_str(), TimezoneSpec::Posix)
1265                .map(|tz| {
1266                    timezone_time(
1267                        tz,
1268                        ds[1].unwrap_time(),
1269                        &ds[2].unwrap_timestamptz().naive_utc(),
1270                    )
1271                    .into()
1272                }),
1273            VariadicFunc::RegexpSplitToArray => {
1274                let flags = if ds.len() == 2 {
1275                    Datum::String("")
1276                } else {
1277                    ds[2]
1278                };
1279                regexp_split_to_array(ds[0], ds[1], flags, temp_storage)
1280            }
1281            VariadicFunc::RegexpReplace => regexp_replace_dynamic(&ds, temp_storage),
1282            VariadicFunc::StringToArray => {
1283                let null_string = if ds.len() == 2 { Datum::Null } else { ds[2] };
1284
1285                string_to_array(ds[0], ds[1], null_string, temp_storage)
1286            }
1287        }
1288    }
1289
1290    pub fn is_associative(&self) -> bool {
1291        match self {
1292            VariadicFunc::Coalesce
1293            | VariadicFunc::Greatest
1294            | VariadicFunc::Least
1295            | VariadicFunc::Concat
1296            | VariadicFunc::And
1297            | VariadicFunc::Or => true,
1298
1299            VariadicFunc::MakeTimestamp
1300            | VariadicFunc::PadLeading
1301            | VariadicFunc::ConcatWs
1302            | VariadicFunc::Substr
1303            | VariadicFunc::Replace
1304            | VariadicFunc::Translate
1305            | VariadicFunc::JsonbBuildArray
1306            | VariadicFunc::JsonbBuildObject
1307            | VariadicFunc::MapBuild { value_type: _ }
1308            | VariadicFunc::ArrayCreate { elem_type: _ }
1309            | VariadicFunc::ArrayToString { elem_type: _ }
1310            | VariadicFunc::ArrayIndex { offset: _ }
1311            | VariadicFunc::ListCreate { elem_type: _ }
1312            | VariadicFunc::RecordCreate { field_names: _ }
1313            | VariadicFunc::ListIndex
1314            | VariadicFunc::ListSliceLinear
1315            | VariadicFunc::SplitPart
1316            | VariadicFunc::RegexpMatch
1317            | VariadicFunc::HmacString
1318            | VariadicFunc::HmacBytes
1319            | VariadicFunc::ErrorIfNull
1320            | VariadicFunc::DateBinTimestamp
1321            | VariadicFunc::DateBinTimestampTz
1322            | VariadicFunc::DateDiffTimestamp
1323            | VariadicFunc::DateDiffTimestampTz
1324            | VariadicFunc::DateDiffDate
1325            | VariadicFunc::DateDiffTime
1326            | VariadicFunc::RangeCreate { .. }
1327            | VariadicFunc::MakeAclItem
1328            | VariadicFunc::MakeMzAclItem
1329            | VariadicFunc::ArrayPosition
1330            | VariadicFunc::ArrayFill { .. }
1331            | VariadicFunc::TimezoneTime
1332            | VariadicFunc::RegexpSplitToArray
1333            | VariadicFunc::StringToArray
1334            | VariadicFunc::RegexpReplace => false,
1335        }
1336    }
1337
1338    pub fn output_type(&self, input_types: Vec<SqlColumnType>) -> SqlColumnType {
1339        use VariadicFunc::*;
1340        let in_nullable = input_types.iter().any(|t| t.nullable);
1341        match self {
1342            Greatest | Least => input_types
1343                .into_iter()
1344                .reduce(|l, r| l.union(&r).unwrap())
1345                .unwrap(),
1346            Coalesce => {
1347                // Note that the parser doesn't allow empty argument lists for variadic functions
1348                // that use the standard function call syntax (ArrayCreate and co. are different
1349                // because of the special syntax for calling them).
1350                let nullable = input_types.iter().all(|typ| typ.nullable);
1351                input_types
1352                    .into_iter()
1353                    .reduce(|l, r| l.union(&r).unwrap())
1354                    .unwrap()
1355                    .nullable(nullable)
1356            }
1357            Concat | ConcatWs => SqlScalarType::String.nullable(in_nullable),
1358            MakeTimestamp => SqlScalarType::Timestamp { precision: None }.nullable(true),
1359            PadLeading => SqlScalarType::String.nullable(in_nullable),
1360            Substr => SqlScalarType::String.nullable(in_nullable),
1361            Replace => SqlScalarType::String.nullable(in_nullable),
1362            Translate => SqlScalarType::String.nullable(in_nullable),
1363            JsonbBuildArray | JsonbBuildObject => SqlScalarType::Jsonb.nullable(true),
1364            MapBuild { value_type } => SqlScalarType::Map {
1365                value_type: Box::new(value_type.clone()),
1366                custom_id: None,
1367            }
1368            .nullable(true),
1369            ArrayCreate { elem_type } => {
1370                debug_assert!(
1371                    input_types
1372                        .iter()
1373                        .all(|t| ReprScalarType::from(&t.scalar_type)
1374                            == ReprScalarType::from(elem_type)),
1375                    "Args to ArrayCreate should have types that are repr-compatible with the elem_type"
1376                );
1377                match elem_type {
1378                    SqlScalarType::Array(_) => elem_type.clone().nullable(false),
1379                    _ => SqlScalarType::Array(Box::new(elem_type.clone())).nullable(false),
1380                }
1381            }
1382            ArrayToString { .. } => SqlScalarType::String.nullable(in_nullable),
1383            ArrayIndex { .. } => input_types[0]
1384                .scalar_type
1385                .unwrap_array_element_type()
1386                .clone()
1387                .nullable(true),
1388            ListCreate { elem_type } => {
1389                // commented out to work around
1390                // https://github.com/MaterializeInc/database-issues/issues/2730
1391                // soft_assert!(
1392                //     input_types.iter().all(|t| t.scalar_type.base_eq(elem_type)),
1393                //     "{}", format!("Args to ListCreate should have types that are compatible with the elem_type.\nArgs:{:#?}\nelem_type:{:#?}", input_types, elem_type)
1394                // );
1395                SqlScalarType::List {
1396                    element_type: Box::new(elem_type.clone()),
1397                    custom_id: None,
1398                }
1399                .nullable(false)
1400            }
1401            ListIndex => input_types[0]
1402                .scalar_type
1403                .unwrap_list_nth_layer_type(input_types.len() - 1)
1404                .clone()
1405                .nullable(true),
1406            ListSliceLinear { .. } => input_types[0].scalar_type.clone().nullable(in_nullable),
1407            RecordCreate { field_names } => SqlScalarType::Record {
1408                fields: field_names
1409                    .clone()
1410                    .into_iter()
1411                    .zip_eq(input_types)
1412                    .collect(),
1413                custom_id: None,
1414            }
1415            .nullable(false),
1416            SplitPart => SqlScalarType::String.nullable(in_nullable),
1417            RegexpMatch => SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true),
1418            HmacString | HmacBytes => SqlScalarType::Bytes.nullable(in_nullable),
1419            ErrorIfNull => input_types[0].scalar_type.clone().nullable(false),
1420            DateBinTimestamp => SqlScalarType::Timestamp { precision: None }.nullable(in_nullable),
1421            DateBinTimestampTz => {
1422                SqlScalarType::TimestampTz { precision: None }.nullable(in_nullable)
1423            }
1424            DateDiffTimestamp => SqlScalarType::Int64.nullable(in_nullable),
1425            DateDiffTimestampTz => SqlScalarType::Int64.nullable(in_nullable),
1426            DateDiffDate => SqlScalarType::Int64.nullable(in_nullable),
1427            DateDiffTime => SqlScalarType::Int64.nullable(in_nullable),
1428            And | Or => SqlScalarType::Bool.nullable(in_nullable),
1429            RangeCreate { elem_type } => SqlScalarType::Range {
1430                element_type: Box::new(elem_type.clone()),
1431            }
1432            .nullable(false),
1433            MakeAclItem => SqlScalarType::AclItem.nullable(true),
1434            MakeMzAclItem => SqlScalarType::MzAclItem.nullable(true),
1435            ArrayPosition => SqlScalarType::Int32.nullable(true),
1436            ArrayFill { elem_type } => {
1437                SqlScalarType::Array(Box::new(elem_type.clone())).nullable(false)
1438            }
1439            TimezoneTime => SqlScalarType::Time.nullable(in_nullable),
1440            RegexpSplitToArray => {
1441                SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(in_nullable)
1442            }
1443            RegexpReplace => SqlScalarType::String.nullable(in_nullable),
1444            StringToArray => SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true),
1445        }
1446    }
1447
1448    /// Whether the function output is NULL if any of its inputs are NULL.
1449    ///
1450    /// NB: if any input is NULL the output will be returned as NULL without
1451    /// calling the function.
1452    pub fn propagates_nulls(&self) -> bool {
1453        // NOTE: The following is a list of the variadic functions
1454        // that **DO NOT** propagate nulls.
1455        !matches!(
1456            self,
1457            VariadicFunc::And
1458                | VariadicFunc::Or
1459                | VariadicFunc::Coalesce
1460                | VariadicFunc::Greatest
1461                | VariadicFunc::Least
1462                | VariadicFunc::Concat
1463                | VariadicFunc::ConcatWs
1464                | VariadicFunc::JsonbBuildArray
1465                | VariadicFunc::JsonbBuildObject
1466                | VariadicFunc::MapBuild { .. }
1467                | VariadicFunc::ListCreate { .. }
1468                | VariadicFunc::RecordCreate { .. }
1469                | VariadicFunc::ArrayCreate { .. }
1470                | VariadicFunc::ArrayToString { .. }
1471                | VariadicFunc::ErrorIfNull
1472                | VariadicFunc::RangeCreate { .. }
1473                | VariadicFunc::ArrayPosition
1474                | VariadicFunc::ArrayFill { .. }
1475                | VariadicFunc::StringToArray
1476        )
1477    }
1478
1479    /// Whether the function might return NULL even if none of its inputs are
1480    /// NULL.
1481    ///
1482    /// This is presently conservative, and may indicate that a function
1483    /// introduces nulls even when it does not.
1484    pub fn introduces_nulls(&self) -> bool {
1485        use VariadicFunc::*;
1486        match self {
1487            Concat
1488            | ConcatWs
1489            | PadLeading
1490            | Substr
1491            | Replace
1492            | Translate
1493            | JsonbBuildArray
1494            | JsonbBuildObject
1495            | MapBuild { .. }
1496            | ArrayCreate { .. }
1497            | ArrayToString { .. }
1498            | ListCreate { .. }
1499            | RecordCreate { .. }
1500            | ListSliceLinear
1501            | SplitPart
1502            | HmacString
1503            | HmacBytes
1504            | ErrorIfNull
1505            | DateBinTimestamp
1506            | DateBinTimestampTz
1507            | DateDiffTimestamp
1508            | DateDiffTimestampTz
1509            | DateDiffDate
1510            | DateDiffTime
1511            | RangeCreate { .. }
1512            | And
1513            | Or
1514            | MakeAclItem
1515            | MakeMzAclItem
1516            | ArrayPosition
1517            | ArrayFill { .. }
1518            | TimezoneTime
1519            | RegexpSplitToArray
1520            | RegexpReplace => false,
1521            Coalesce
1522            | Greatest
1523            | Least
1524            | MakeTimestamp
1525            | ArrayIndex { .. }
1526            | StringToArray
1527            | ListIndex
1528            | RegexpMatch => true,
1529        }
1530    }
1531
1532    pub fn switch_and_or(&self) -> Self {
1533        match self {
1534            VariadicFunc::And => VariadicFunc::Or,
1535            VariadicFunc::Or => VariadicFunc::And,
1536            _ => unreachable!(),
1537        }
1538    }
1539
1540    pub fn is_infix_op(&self) -> bool {
1541        use VariadicFunc::*;
1542        matches!(self, And | Or)
1543    }
1544
1545    /// Gives the unit (u) of OR or AND, such that `u AND/OR x == x`.
1546    /// Note that a 0-arg AND/OR evaluates to unit_of_and_or.
1547    pub fn unit_of_and_or(&self) -> MirScalarExpr {
1548        match self {
1549            VariadicFunc::And => MirScalarExpr::literal_true(),
1550            VariadicFunc::Or => MirScalarExpr::literal_false(),
1551            _ => unreachable!(),
1552        }
1553    }
1554
1555    /// Gives the zero (z) of OR or AND, such that `z AND/OR x == z`.
1556    pub fn zero_of_and_or(&self) -> MirScalarExpr {
1557        match self {
1558            VariadicFunc::And => MirScalarExpr::literal_false(),
1559            VariadicFunc::Or => MirScalarExpr::literal_true(),
1560            _ => unreachable!(),
1561        }
1562    }
1563
1564    /// Returns true if the function could introduce an error on non-error inputs.
1565    pub fn could_error(&self) -> bool {
1566        match self {
1567            VariadicFunc::And | VariadicFunc::Or => false,
1568            VariadicFunc::Coalesce => false,
1569            VariadicFunc::Greatest | VariadicFunc::Least => false,
1570            VariadicFunc::Concat | VariadicFunc::ConcatWs => false,
1571            VariadicFunc::Replace => false,
1572            VariadicFunc::Translate => false,
1573            VariadicFunc::ArrayIndex { .. } => false,
1574            VariadicFunc::ListCreate { .. } | VariadicFunc::RecordCreate { .. } => false,
1575            // All other cases are unknown
1576            _ => true,
1577        }
1578    }
1579
1580    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
1581    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
1582    /// determine the range of possible outputs just by mapping the endpoints.
1583    ///
1584    /// This describes the *pointwise* behaviour of the function:
1585    /// ie. if more than one argument is provided, this describes the behaviour of
1586    /// any specific argument as the others are held constant. (For example, `COALESCE(a, b)` is
1587    /// monotone in `a` because for any particular value of `b`, increasing `a` will never
1588    /// cause the result to decrease.)
1589    ///
1590    /// This property describes the behaviour of the function over ranges where the function is defined:
1591    /// ie. the arguments and the result are non-error datums.
1592    pub fn is_monotone(&self) -> bool {
1593        match self {
1594            VariadicFunc::Coalesce
1595            | VariadicFunc::Greatest
1596            | VariadicFunc::Least
1597            | VariadicFunc::And
1598            | VariadicFunc::Or => true,
1599            VariadicFunc::Concat
1600            | VariadicFunc::ConcatWs
1601            | VariadicFunc::MakeTimestamp
1602            | VariadicFunc::PadLeading
1603            | VariadicFunc::Substr
1604            | VariadicFunc::Replace
1605            | VariadicFunc::JsonbBuildArray
1606            | VariadicFunc::JsonbBuildObject
1607            | VariadicFunc::MapBuild { .. }
1608            | VariadicFunc::ArrayCreate { .. }
1609            | VariadicFunc::ArrayToString { .. }
1610            | VariadicFunc::ArrayIndex { .. }
1611            | VariadicFunc::ListCreate { .. }
1612            | VariadicFunc::RecordCreate { .. }
1613            | VariadicFunc::ListIndex
1614            | VariadicFunc::ListSliceLinear
1615            | VariadicFunc::SplitPart
1616            | VariadicFunc::RegexpMatch
1617            | VariadicFunc::HmacString
1618            | VariadicFunc::HmacBytes
1619            | VariadicFunc::ErrorIfNull
1620            | VariadicFunc::DateBinTimestamp
1621            | VariadicFunc::DateBinTimestampTz
1622            | VariadicFunc::RangeCreate { .. }
1623            | VariadicFunc::MakeAclItem
1624            | VariadicFunc::MakeMzAclItem
1625            | VariadicFunc::Translate
1626            | VariadicFunc::ArrayPosition
1627            | VariadicFunc::ArrayFill { .. }
1628            | VariadicFunc::DateDiffTimestamp
1629            | VariadicFunc::DateDiffTimestampTz
1630            | VariadicFunc::DateDiffDate
1631            | VariadicFunc::DateDiffTime
1632            | VariadicFunc::TimezoneTime
1633            | VariadicFunc::RegexpSplitToArray
1634            | VariadicFunc::StringToArray
1635            | VariadicFunc::RegexpReplace => false,
1636        }
1637    }
1638}
1639
1640impl std::fmt::Display for VariadicFunc {
1641    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1642        match self {
1643            VariadicFunc::Coalesce => f.write_str("coalesce"),
1644            VariadicFunc::Greatest => f.write_str("greatest"),
1645            VariadicFunc::Least => f.write_str("least"),
1646            VariadicFunc::Concat => f.write_str("concat"),
1647            VariadicFunc::ConcatWs => f.write_str("concat_ws"),
1648            VariadicFunc::MakeTimestamp => f.write_str("makets"),
1649            VariadicFunc::PadLeading => f.write_str("lpad"),
1650            VariadicFunc::Substr => f.write_str("substr"),
1651            VariadicFunc::Replace => f.write_str("replace"),
1652            VariadicFunc::Translate => f.write_str("translate"),
1653            VariadicFunc::JsonbBuildArray => f.write_str("jsonb_build_array"),
1654            VariadicFunc::JsonbBuildObject => f.write_str("jsonb_build_object"),
1655            VariadicFunc::MapBuild { .. } => f.write_str("map_build"),
1656            VariadicFunc::ArrayCreate { .. } => f.write_str("array_create"),
1657            VariadicFunc::ArrayToString { .. } => f.write_str("array_to_string"),
1658            VariadicFunc::ArrayIndex { .. } => f.write_str("array_index"),
1659            VariadicFunc::ListCreate { .. } => f.write_str("list_create"),
1660            VariadicFunc::RecordCreate { .. } => f.write_str("record_create"),
1661            VariadicFunc::ListIndex => f.write_str("list_index"),
1662            VariadicFunc::ListSliceLinear => f.write_str("list_slice_linear"),
1663            VariadicFunc::SplitPart => f.write_str("split_string"),
1664            VariadicFunc::RegexpMatch => f.write_str("regexp_match"),
1665            VariadicFunc::HmacString | VariadicFunc::HmacBytes => f.write_str("hmac"),
1666            VariadicFunc::ErrorIfNull => f.write_str("error_if_null"),
1667            VariadicFunc::DateBinTimestamp => f.write_str("timestamp_bin"),
1668            VariadicFunc::DateBinTimestampTz => f.write_str("timestamptz_bin"),
1669            VariadicFunc::DateDiffTimestamp
1670            | VariadicFunc::DateDiffTimestampTz
1671            | VariadicFunc::DateDiffDate
1672            | VariadicFunc::DateDiffTime => f.write_str("datediff"),
1673            VariadicFunc::And => f.write_str("AND"),
1674            VariadicFunc::Or => f.write_str("OR"),
1675            VariadicFunc::RangeCreate {
1676                elem_type: element_type,
1677            } => f.write_str(match element_type {
1678                SqlScalarType::Int32 => "int4range",
1679                SqlScalarType::Int64 => "int8range",
1680                SqlScalarType::Date => "daterange",
1681                SqlScalarType::Numeric { .. } => "numrange",
1682                SqlScalarType::Timestamp { .. } => "tsrange",
1683                SqlScalarType::TimestampTz { .. } => "tstzrange",
1684                _ => unreachable!(),
1685            }),
1686            VariadicFunc::MakeAclItem => f.write_str("makeaclitem"),
1687            VariadicFunc::MakeMzAclItem => f.write_str("make_mz_aclitem"),
1688            VariadicFunc::ArrayPosition => f.write_str("array_position"),
1689            VariadicFunc::ArrayFill { .. } => f.write_str("array_fill"),
1690            VariadicFunc::TimezoneTime => f.write_str("timezonet"),
1691            VariadicFunc::RegexpSplitToArray => f.write_str("regexp_split_to_array"),
1692            VariadicFunc::RegexpReplace => f.write_str("regexp_replace"),
1693            VariadicFunc::StringToArray => f.write_str("string_to_array"),
1694        }
1695    }
1696}