Skip to main content

mz_expr/scalar/func/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Variadic functions.
15
16use std::borrow::Cow;
17use std::cmp;
18
19use chrono::NaiveDate;
20use fallible_iterator::FallibleIterator;
21use hmac::{Hmac, Mac};
22use itertools::Itertools;
23use md5::Md5;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, ReinterpretCast};
26use mz_ore::soft_assert_or_log;
27use mz_pgtz::timezone::TimezoneSpec;
28use mz_repr::adt::array::{ArrayDimension, ArrayDimensions, InvalidArrayError};
29use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem};
30use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags};
31use mz_repr::adt::system::Oid;
32use mz_repr::adt::timestamp::CheckedTimestamp;
33use mz_repr::role_id::RoleId;
34use mz_repr::{
35    ColumnName, Datum, OutputDatumType, ReprScalarType, Row, RowArena, SqlColumnType, SqlScalarType,
36};
37use serde::{Deserialize, Serialize};
38use sha1::Sha1;
39use sha2::{Sha224, Sha256, Sha384, Sha512};
40
41use crate::func::{
42    MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, parse_timezone,
43    regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re, stringify_datum,
44    timezone_time,
45};
46use crate::{EvalError, MirScalarExpr};
47
48pub fn and<'a>(
49    datums: &[Datum<'a>],
50    temp_storage: &'a RowArena,
51    exprs: &'a [MirScalarExpr],
52) -> Result<Datum<'a>, EvalError> {
53    // If any is false, then return false. Else, if any is null, then return null. Else, return true.
54    let mut null = false;
55    let mut err = None;
56    for expr in exprs {
57        match expr.eval(datums, temp_storage) {
58            Ok(Datum::False) => return Ok(Datum::False), // short-circuit
59            Ok(Datum::True) => {}
60            // No return in these two cases, because we might still see a false
61            Ok(Datum::Null) => null = true,
62            Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
63            _ => unreachable!(),
64        }
65    }
66    match (err, null) {
67        (Some(err), _) => Err(err),
68        (None, true) => Ok(Datum::Null),
69        (None, false) => Ok(Datum::True),
70    }
71}
72
73/// Constructs a new multidimensional array out of an arbitrary number of
74/// lower-dimensional arrays.
75///
76/// For example, if given three 1D arrays of length 2, this function will
77/// construct a 2D array with dimensions 3x2.
78///
79/// The input datums in `datums` must all be arrays of the same dimensions.
80/// (The arrays must also be of the same element type, but that is checked by
81/// the SQL type system, rather than checked here at runtime.)
82///
83/// If all input arrays are zero-dimensional arrays, then the output is a zero-
84/// dimensional array. Otherwise, the lower bound of the additional dimension is
85/// one and the length of the new dimension is equal to `datums.len()`.
86///
87/// Null elements are allowed and considered to be zero-dimensional arrays.
88fn array_create_multidim<'a>(
89    datums: &[Datum<'a>],
90    temp_storage: &'a RowArena,
91) -> Result<Datum<'a>, EvalError> {
92    let mut dim: Option<ArrayDimensions> = None;
93    for datum in datums {
94        let actual_dims = match datum {
95            Datum::Null => ArrayDimensions::default(),
96            Datum::Array(arr) => arr.dims(),
97            d => panic!("unexpected datum {d}"),
98        };
99        if let Some(expected) = &dim {
100            if actual_dims.ndims() != expected.ndims() {
101                let actual = actual_dims.ndims().into();
102                let expected = expected.ndims().into();
103                // All input arrays must have the same dimensionality.
104                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
105            }
106            if let Some((e, a)) = expected
107                .into_iter()
108                .zip_eq(actual_dims.into_iter())
109                .find(|(e, a)| e != a)
110            {
111                let actual = a.length;
112                let expected = e.length;
113                // All input arrays must have the same dimensionality.
114                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
115            }
116        }
117        dim = Some(actual_dims);
118    }
119    // Per PostgreSQL, if all input arrays are zero dimensional, so is the output.
120    if dim.as_ref().map_or(true, ArrayDimensions::is_empty) {
121        return Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&[], &[]))?);
122    }
123
124    let mut dims = vec![ArrayDimension {
125        lower_bound: 1,
126        length: datums.len(),
127    }];
128    if let Some(d) = datums.first() {
129        dims.extend(d.unwrap_array().dims());
130    };
131    let elements = datums
132        .iter()
133        .flat_map(|d| d.unwrap_array().elements().iter());
134    let datum =
135        temp_storage.try_make_datum(move |packer| packer.try_push_array(&dims, elements))?;
136    Ok(datum)
137}
138
139fn array_fill<'a>(
140    datums: &[Datum<'a>],
141    temp_storage: &'a RowArena,
142) -> Result<Datum<'a>, EvalError> {
143    const MAX_SIZE: usize = 1 << 28 - 1;
144    const NULL_ARR_ERR: &str = "dimension array or low bound array";
145    const NULL_ELEM_ERR: &str = "dimension values";
146
147    let fill = datums[0];
148    if matches!(fill, Datum::Array(_)) {
149        return Err(EvalError::Unsupported {
150            feature: "array_fill with arrays".into(),
151            discussion_no: None,
152        });
153    }
154
155    let arr = match datums[1] {
156        Datum::Null => return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into())),
157        o => o.unwrap_array(),
158    };
159
160    let dimensions = arr
161        .elements()
162        .iter()
163        .map(|d| match d {
164            Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
165            d => Ok(usize::cast_from(u32::reinterpret_cast(d.unwrap_int32()))),
166        })
167        .collect::<Result<Vec<_>, _>>()?;
168
169    let lower_bounds = match datums.get(2) {
170        Some(d) => {
171            let arr = match d {
172                Datum::Null => return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into())),
173                o => o.unwrap_array(),
174            };
175
176            arr.elements()
177                .iter()
178                .map(|l| match l {
179                    Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
180                    l => Ok(isize::cast_from(l.unwrap_int32())),
181                })
182                .collect::<Result<Vec<_>, _>>()?
183        }
184        None => {
185            vec![1isize; dimensions.len()]
186        }
187    };
188
189    if lower_bounds.len() != dimensions.len() {
190        return Err(EvalError::ArrayFillWrongArraySubscripts);
191    }
192
193    let fill_count: usize = dimensions
194        .iter()
195        .cloned()
196        .map(Some)
197        .reduce(|a, b| match (a, b) {
198            (Some(a), Some(b)) => a.checked_mul(b),
199            _ => None,
200        })
201        .flatten()
202        .ok_or(EvalError::MaxArraySizeExceeded(MAX_SIZE))?;
203
204    if matches!(
205        mz_repr::datum_size(&fill).checked_mul(fill_count),
206        None | Some(MAX_SIZE..)
207    ) {
208        return Err(EvalError::MaxArraySizeExceeded(MAX_SIZE));
209    }
210
211    let array_dimensions = if fill_count == 0 {
212        vec![ArrayDimension {
213            lower_bound: 1,
214            length: 0,
215        }]
216    } else {
217        dimensions
218            .into_iter()
219            .zip_eq(lower_bounds)
220            .map(|(length, lower_bound)| ArrayDimension {
221                lower_bound,
222                length,
223            })
224            .collect()
225    };
226
227    Ok(temp_storage.try_make_datum(|packer| {
228        packer.try_push_array(&array_dimensions, vec![fill; fill_count])
229    })?)
230}
231
232fn array_index<'a>(datums: &[Datum<'a>], offset: i64) -> Datum<'a> {
233    mz_ore::soft_assert_no_log!(offset == 0 || offset == 1, "offset must be either 0 or 1");
234
235    let array = datums[0].unwrap_array();
236    let dims = array.dims();
237    if dims.len() != datums.len() - 1 {
238        // You missed the datums "layer"
239        return Datum::Null;
240    }
241
242    let mut final_idx = 0;
243
244    for (d, idx) in dims.into_iter().zip_eq(datums[1..].iter()) {
245        // Lower bound is written in terms of 1-based indexing, which offset accounts for.
246        let idx = isize::cast_from(idx.unwrap_int64() + offset);
247
248        let (lower, upper) = d.dimension_bounds();
249
250        // This index missed all of the data at this layer. The dimension bounds are inclusive,
251        // while range checks are exclusive, so adjust.
252        if !(lower..upper + 1).contains(&idx) {
253            return Datum::Null;
254        }
255
256        // We discover how many indices our last index represents physically.
257        final_idx *= d.length;
258
259        // Because both index and lower bound are handled in 1-based indexing, taking their
260        // difference moves us back into 0-based indexing. Similarly, if the lower bound is
261        // negative, subtracting a negative value >= to itself ensures its non-negativity.
262        final_idx += usize::try_from(idx - d.lower_bound)
263            .expect("previous bounds check ensures phsical index is at least 0");
264    }
265
266    array
267        .elements()
268        .iter()
269        .nth(final_idx)
270        .unwrap_or(Datum::Null)
271}
272
273fn array_position<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
274    let array = match datums[0] {
275        Datum::Null => return Ok(Datum::Null),
276        o => o.unwrap_array(),
277    };
278
279    if array.dims().len() > 1 {
280        return Err(EvalError::MultiDimensionalArraySearch);
281    }
282
283    let search = datums[1];
284    if search == Datum::Null {
285        return Ok(Datum::Null);
286    }
287
288    let skip: usize = match datums.get(2) {
289        Some(Datum::Null) => return Err(EvalError::MustNotBeNull("initial position".into())),
290        None => 0,
291        Some(o) => usize::try_from(o.unwrap_int32())
292            .unwrap_or(0)
293            .saturating_sub(1),
294    };
295
296    let r = array.elements().iter().skip(skip).position(|d| d == search);
297
298    Ok(Datum::from(r.map(|p| {
299        // Adjust count for the amount we skipped, plus 1 for adjustng to PG indexing scheme.
300        i32::try_from(p + skip + 1).expect("fewer than i32::MAX elements in array")
301    })))
302}
303
304// WARNING: This function has potential OOM risk!
305// It is very difficult to calculate the output size ahead of time without knowing how to
306// calculate the stringified size of each element for all possible datatypes.
307fn array_to_string<'a>(
308    datums: &[Datum<'a>],
309    elem_type: &SqlScalarType,
310    temp_storage: &'a RowArena,
311) -> Result<Datum<'a>, EvalError> {
312    if datums[0].is_null() || datums[1].is_null() {
313        return Ok(Datum::Null);
314    }
315    let array = datums[0].unwrap_array();
316    let delimiter = datums[1].unwrap_str();
317    let null_str = match datums.get(2) {
318        None | Some(Datum::Null) => None,
319        Some(d) => Some(d.unwrap_str()),
320    };
321
322    let mut out = String::new();
323    for elem in array.elements().iter() {
324        if elem.is_null() {
325            if let Some(null_str) = null_str {
326                out.push_str(null_str);
327                out.push_str(delimiter);
328            }
329        } else {
330            stringify_datum(&mut out, elem, elem_type)?;
331            out.push_str(delimiter);
332        }
333    }
334    if out.len() > 0 {
335        // Lop off last delimiter only if string is not empty
336        out.truncate(out.len() - delimiter.len());
337    }
338    Ok(Datum::String(temp_storage.push_string(out)))
339}
340
341fn coalesce<'a>(
342    datums: &[Datum<'a>],
343    temp_storage: &'a RowArena,
344    exprs: &'a [MirScalarExpr],
345) -> Result<Datum<'a>, EvalError> {
346    for e in exprs {
347        let d = e.eval(datums, temp_storage)?;
348        if !d.is_null() {
349            return Ok(d);
350        }
351    }
352    Ok(Datum::Null)
353}
354
355fn create_range<'a>(
356    datums: &[Datum<'a>],
357    temp_storage: &'a RowArena,
358) -> Result<Datum<'a>, EvalError> {
359    let flags = match datums[2] {
360        Datum::Null => {
361            return Err(EvalError::InvalidRange(
362                InvalidRangeError::NullRangeBoundFlags,
363            ));
364        }
365        o => o.unwrap_str(),
366    };
367
368    let (lower_inclusive, upper_inclusive) = parse_range_bound_flags(flags)?;
369
370    let mut range = Range::new(Some((
371        RangeBound::new(datums[0], lower_inclusive),
372        RangeBound::new(datums[1], upper_inclusive),
373    )));
374
375    range.canonicalize()?;
376
377    Ok(temp_storage.make_datum(|row| {
378        row.push_range(range).expect("errors already handled");
379    }))
380}
381
382fn date_diff_date<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
383    let unit = unit.unwrap_str();
384    let unit = unit
385        .parse()
386        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
387
388    let a = a.unwrap_date();
389    let b = b.unwrap_date();
390
391    // Convert the Date into a timestamp so we can calculate age.
392    let a_ts = CheckedTimestamp::try_from(NaiveDate::from(a).and_hms_opt(0, 0, 0).unwrap())?;
393    let b_ts = CheckedTimestamp::try_from(NaiveDate::from(b).and_hms_opt(0, 0, 0).unwrap())?;
394    let diff = b_ts.diff_as(&a_ts, unit)?;
395
396    Ok(Datum::Int64(diff))
397}
398
399fn date_diff_time<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
400    let unit = unit.unwrap_str();
401    let unit = unit
402        .parse()
403        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
404
405    let a = a.unwrap_time();
406    let b = b.unwrap_time();
407
408    // Convert the Time into a timestamp so we can calculate age.
409    let a_ts =
410        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(a))?;
411    let b_ts =
412        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(b))?;
413    let diff = b_ts.diff_as(&a_ts, unit)?;
414
415    Ok(Datum::Int64(diff))
416}
417
418fn date_diff_timestamp<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
419    let unit = unit.unwrap_str();
420    let unit = unit
421        .parse()
422        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
423
424    let a = a.unwrap_timestamp();
425    let b = b.unwrap_timestamp();
426    let diff = b.diff_as(&a, unit)?;
427
428    Ok(Datum::Int64(diff))
429}
430
431fn date_diff_timestamptz<'a>(unit: Datum, a: Datum, b: Datum) -> Result<Datum<'a>, EvalError> {
432    let unit = unit.unwrap_str();
433    let unit = unit
434        .parse()
435        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
436
437    let a = a.unwrap_timestamptz();
438    let b = b.unwrap_timestamptz();
439    let diff = b.diff_as(&a, unit)?;
440
441    Ok(Datum::Int64(diff))
442}
443
444fn error_if_null<'a>(
445    datums: &[Datum<'a>],
446    temp_storage: &'a RowArena,
447    exprs: &'a [MirScalarExpr],
448) -> Result<Datum<'a>, EvalError> {
449    let first = exprs[0].eval(datums, temp_storage)?;
450    match first {
451        Datum::Null => {
452            let err_msg = match exprs[1].eval(datums, temp_storage)? {
453                Datum::Null => {
454                    return Err(EvalError::Internal(
455                        "unexpected NULL in error side of error_if_null".into(),
456                    ));
457                }
458                o => o.unwrap_str(),
459            };
460            Err(EvalError::IfNullError(err_msg.into()))
461        }
462        _ => Ok(first),
463    }
464}
465
466fn greatest<'a>(
467    datums: &[Datum<'a>],
468    temp_storage: &'a RowArena,
469    exprs: &'a [MirScalarExpr],
470) -> Result<Datum<'a>, EvalError> {
471    let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
472    Ok(datums
473        .filter(|d| Ok(!d.is_null()))
474        .max()?
475        .unwrap_or(Datum::Null))
476}
477
478pub fn hmac_string<'a>(
479    datums: &[Datum<'a>],
480    temp_storage: &'a RowArena,
481) -> Result<Datum<'a>, EvalError> {
482    let to_digest = datums[0].unwrap_str().as_bytes();
483    let key = datums[1].unwrap_str().as_bytes();
484    let typ = datums[2].unwrap_str();
485    hmac_inner(to_digest, key, typ, temp_storage)
486}
487
488pub fn hmac_bytes<'a>(
489    datums: &[Datum<'a>],
490    temp_storage: &'a RowArena,
491) -> Result<Datum<'a>, EvalError> {
492    let to_digest = datums[0].unwrap_bytes();
493    let key = datums[1].unwrap_bytes();
494    let typ = datums[2].unwrap_str();
495    hmac_inner(to_digest, key, typ, temp_storage)
496}
497
498pub fn hmac_inner<'a>(
499    to_digest: &[u8],
500    key: &[u8],
501    typ: &str,
502    temp_storage: &'a RowArena,
503) -> Result<Datum<'a>, EvalError> {
504    let bytes = match typ {
505        "md5" => {
506            let mut mac = Hmac::<Md5>::new_from_slice(key).expect("HMAC accepts any key size");
507            mac.update(to_digest);
508            mac.finalize().into_bytes().to_vec()
509        }
510        "sha1" => {
511            let mut mac = Hmac::<Sha1>::new_from_slice(key).expect("HMAC accepts any key size");
512            mac.update(to_digest);
513            mac.finalize().into_bytes().to_vec()
514        }
515        "sha224" => {
516            let mut mac = Hmac::<Sha224>::new_from_slice(key).expect("HMAC accepts any key size");
517            mac.update(to_digest);
518            mac.finalize().into_bytes().to_vec()
519        }
520        "sha256" => {
521            let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key size");
522            mac.update(to_digest);
523            mac.finalize().into_bytes().to_vec()
524        }
525        "sha384" => {
526            let mut mac = Hmac::<Sha384>::new_from_slice(key).expect("HMAC accepts any key size");
527            mac.update(to_digest);
528            mac.finalize().into_bytes().to_vec()
529        }
530        "sha512" => {
531            let mut mac = Hmac::<Sha512>::new_from_slice(key).expect("HMAC accepts any key size");
532            mac.update(to_digest);
533            mac.finalize().into_bytes().to_vec()
534        }
535        other => return Err(EvalError::InvalidHashAlgorithm(other.into())),
536    };
537    Ok(Datum::Bytes(temp_storage.push_bytes(bytes)))
538}
539
540fn jsonb_build_array<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
541    temp_storage.make_datum(|packer| {
542        packer.push_list(datums.into_iter().map(|d| match d {
543            Datum::Null => Datum::JsonNull,
544            d => *d,
545        }))
546    })
547}
548
549fn jsonb_build_object<'a>(
550    datums: &[Datum<'a>],
551    temp_storage: &'a RowArena,
552) -> Result<Datum<'a>, EvalError> {
553    let mut kvs = datums.chunks(2).collect::<Vec<_>>();
554    kvs.sort_by(|kv1, kv2| kv1[0].cmp(&kv2[0]));
555    kvs.dedup_by(|kv1, kv2| kv1[0] == kv2[0]);
556    temp_storage.try_make_datum(|packer| {
557        packer.push_dict_with(|packer| {
558            for kv in kvs {
559                let k = kv[0];
560                if k.is_null() {
561                    return Err(EvalError::KeyCannotBeNull);
562                };
563                let v = match kv[1] {
564                    Datum::Null => Datum::JsonNull,
565                    d => d,
566                };
567                packer.push(k);
568                packer.push(v);
569            }
570            Ok(())
571        })
572    })
573}
574
575fn least<'a>(
576    datums: &[Datum<'a>],
577    temp_storage: &'a RowArena,
578    exprs: &'a [MirScalarExpr],
579) -> Result<Datum<'a>, EvalError> {
580    let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
581    Ok(datums
582        .filter(|d| Ok(!d.is_null()))
583        .min()?
584        .unwrap_or(Datum::Null))
585}
586
587fn list_create<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
588    temp_storage.make_datum(|packer| packer.push_list(datums))
589}
590
591// TODO(benesch): remove potentially dangerous usage of `as`.
592#[allow(clippy::as_conversions)]
593fn list_index<'a>(datums: &[Datum<'a>]) -> Datum<'a> {
594    let mut buf = datums[0];
595
596    for i in datums[1..].iter() {
597        if buf.is_null() {
598            break;
599        }
600
601        let i = i.unwrap_int64();
602        if i < 1 {
603            return Datum::Null;
604        }
605
606        buf = buf
607            .unwrap_list()
608            .iter()
609            .nth(i as usize - 1)
610            .unwrap_or(Datum::Null);
611    }
612    buf
613}
614
615fn make_acl_item<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
616    let grantee = Oid(datums[0].unwrap_uint32());
617    let grantor = Oid(datums[1].unwrap_uint32());
618    let privileges = datums[2].unwrap_str();
619    let acl_mode = AclMode::parse_multiple_privileges(privileges)
620        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
621    let is_grantable = datums[3].unwrap_bool();
622    if is_grantable {
623        return Err(EvalError::Unsupported {
624            feature: "GRANT OPTION".into(),
625            discussion_no: None,
626        });
627    }
628
629    Ok(Datum::AclItem(AclItem {
630        grantee,
631        grantor,
632        acl_mode,
633    }))
634}
635
636fn make_mz_acl_item<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
637    let grantee: RoleId = datums[0]
638        .unwrap_str()
639        .parse()
640        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
641    let grantor: RoleId = datums[1]
642        .unwrap_str()
643        .parse()
644        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
645    if grantor == RoleId::Public {
646        return Err(EvalError::InvalidRoleId(
647            "mz_aclitem grantor cannot be PUBLIC role".into(),
648        ));
649    }
650    let privileges = datums[2].unwrap_str();
651    let acl_mode = AclMode::parse_multiple_privileges(privileges)
652        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
653
654    Ok(Datum::MzAclItem(MzAclItem {
655        grantee,
656        grantor,
657        acl_mode,
658    }))
659}
660
661// TODO(benesch): remove potentially dangerous usage of `as`.
662#[allow(clippy::as_conversions)]
663fn make_timestamp<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
664    let year: i32 = match datums[0].unwrap_int64().try_into() {
665        Ok(year) => year,
666        Err(_) => return Ok(Datum::Null),
667    };
668    let month: u32 = match datums[1].unwrap_int64().try_into() {
669        Ok(month) => month,
670        Err(_) => return Ok(Datum::Null),
671    };
672    let day: u32 = match datums[2].unwrap_int64().try_into() {
673        Ok(day) => day,
674        Err(_) => return Ok(Datum::Null),
675    };
676    let hour: u32 = match datums[3].unwrap_int64().try_into() {
677        Ok(day) => day,
678        Err(_) => return Ok(Datum::Null),
679    };
680    let minute: u32 = match datums[4].unwrap_int64().try_into() {
681        Ok(day) => day,
682        Err(_) => return Ok(Datum::Null),
683    };
684    let second_float = datums[5].unwrap_float64();
685    let second = second_float as u32;
686    let micros = ((second_float - second as f64) * 1_000_000.0) as u32;
687    let date = match NaiveDate::from_ymd_opt(year, month, day) {
688        Some(date) => date,
689        None => return Ok(Datum::Null),
690    };
691    let timestamp = match date.and_hms_micro_opt(hour, minute, second, micros) {
692        Some(timestamp) => timestamp,
693        None => return Ok(Datum::Null),
694    };
695    Ok(timestamp.try_into()?)
696}
697
698fn map_build<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
699    // Collect into a `BTreeMap` to provide the same semantics as it.
700    let map: std::collections::BTreeMap<&str, _> = datums
701        .into_iter()
702        .tuples()
703        .filter_map(|(k, v)| {
704            if k.is_null() {
705                None
706            } else {
707                Some((k.unwrap_str(), v))
708            }
709        })
710        .collect();
711
712    temp_storage.make_datum(|packer| packer.push_dict(map))
713}
714
715pub fn or<'a>(
716    datums: &[Datum<'a>],
717    temp_storage: &'a RowArena,
718    exprs: &'a [MirScalarExpr],
719) -> Result<Datum<'a>, EvalError> {
720    // If any is true, then return true. Else, if any is null, then return null. Else, return false.
721    let mut null = false;
722    let mut err = None;
723    for expr in exprs {
724        match expr.eval(datums, temp_storage) {
725            Ok(Datum::False) => {}
726            Ok(Datum::True) => return Ok(Datum::True), // short-circuit
727            // No return in these two cases, because we might still see a true
728            Ok(Datum::Null) => null = true,
729            Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
730            _ => unreachable!(),
731        }
732    }
733    match (err, null) {
734        (Some(err), _) => Err(err),
735        (None, true) => Ok(Datum::Null),
736        (None, false) => Ok(Datum::False),
737    }
738}
739
740fn pad_leading<'a>(
741    datums: &[Datum<'a>],
742    temp_storage: &'a RowArena,
743) -> Result<Datum<'a>, EvalError> {
744    let string = datums[0].unwrap_str();
745
746    let len = match usize::try_from(datums[1].unwrap_int32()) {
747        Ok(len) => len,
748        Err(_) => {
749            return Err(EvalError::InvalidParameterValue(
750                "length must be nonnegative".into(),
751            ));
752        }
753    };
754    if len > MAX_STRING_FUNC_RESULT_BYTES {
755        return Err(EvalError::LengthTooLarge);
756    }
757
758    let pad_string = if datums.len() == 3 {
759        datums[2].unwrap_str()
760    } else {
761        " "
762    };
763
764    let (end_char, end_char_byte_offset) = string
765        .chars()
766        .take(len)
767        .fold((0, 0), |acc, char| (acc.0 + 1, acc.1 + char.len_utf8()));
768
769    let mut buf = String::with_capacity(len);
770    if len == end_char {
771        buf.push_str(&string[0..end_char_byte_offset]);
772    } else {
773        buf.extend(pad_string.chars().cycle().take(len - end_char));
774        buf.push_str(string);
775    }
776
777    Ok(Datum::String(temp_storage.push_string(buf)))
778}
779
780fn regexp_match_dynamic<'a>(
781    datums: &[Datum<'a>],
782    temp_storage: &'a RowArena,
783) -> Result<Datum<'a>, EvalError> {
784    let haystack = datums[0];
785    let needle = datums[1].unwrap_str();
786    let flags = match datums.get(2) {
787        Some(d) => d.unwrap_str(),
788        None => "",
789    };
790    let needle = build_regex(needle, flags)?;
791    regexp_match_static(haystack, temp_storage, &needle)
792}
793
794fn regexp_split_to_array<'a>(
795    text: Datum<'a>,
796    regexp: Datum<'a>,
797    flags: Datum<'a>,
798    temp_storage: &'a RowArena,
799) -> Result<Datum<'a>, EvalError> {
800    let text = text.unwrap_str();
801    let regexp = regexp.unwrap_str();
802    let flags = flags.unwrap_str();
803    let regexp = build_regex(regexp, flags)?;
804    regexp_split_to_array_re(text, &regexp, temp_storage)
805}
806
807fn regexp_replace_dynamic<'a>(
808    datums: &[Datum<'a>],
809    temp_storage: &'a RowArena,
810) -> Result<Datum<'a>, EvalError> {
811    let source = datums[0];
812    let pattern = datums[1];
813    let replacement = datums[2];
814    let flags = match datums.get(3) {
815        Some(d) => d.unwrap_str(),
816        None => "",
817    };
818    let (limit, flags) = regexp_replace_parse_flags(flags);
819    let regexp = build_regex(pattern.unwrap_str(), &flags)?;
820    let replaced = match regexp.replacen(source.unwrap_str(), limit, replacement.unwrap_str()) {
821        Cow::Borrowed(s) => s,
822        Cow::Owned(s) => temp_storage.push_string(s),
823    };
824    Ok(Datum::String(replaced))
825}
826
827fn replace<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
828    // As a compromise to avoid always nearly duplicating the work of replace by doing size estimation,
829    // we first check if its possible for the fully replaced string to exceed the limit by assuming that
830    // every possible substring is replaced.
831    //
832    // If that estimate exceeds the limit, we then do a more precise (and expensive) estimate by counting
833    // the actual number of replacements that would occur, and using that to calculate the final size.
834    let text = datums[0].unwrap_str();
835    let from = datums[1].unwrap_str();
836    let to = datums[2].unwrap_str();
837    let possible_size = text.len() * to.len();
838    if possible_size > MAX_STRING_FUNC_RESULT_BYTES {
839        let replacement_count = text.matches(from).count();
840        let estimated_size = text.len() + replacement_count * (to.len().saturating_sub(from.len()));
841        if estimated_size > MAX_STRING_FUNC_RESULT_BYTES {
842            return Err(EvalError::LengthTooLarge);
843        }
844    }
845
846    Ok(Datum::String(
847        temp_storage.push_string(text.replace(from, to)),
848    ))
849}
850
851fn string_to_array<'a>(
852    string_datum: Datum<'a>,
853    delimiter: Datum<'a>,
854    null_string: Datum<'a>,
855    temp_storage: &'a RowArena,
856) -> Result<Datum<'a>, EvalError> {
857    if string_datum.is_null() {
858        return Ok(Datum::Null);
859    }
860
861    let string = string_datum.unwrap_str();
862
863    if string.is_empty() {
864        let mut row = Row::default();
865        let mut packer = row.packer();
866        packer.try_push_array(&[], std::iter::empty::<Datum>())?;
867
868        return Ok(temp_storage.push_unary_row(row));
869    }
870
871    if delimiter.is_null() {
872        let split_all_chars_delimiter = "";
873        return string_to_array_impl(string, split_all_chars_delimiter, null_string, temp_storage);
874    }
875
876    let delimiter = delimiter.unwrap_str();
877
878    if delimiter.is_empty() {
879        let mut row = Row::default();
880        let mut packer = row.packer();
881        packer.try_push_array(
882            &[ArrayDimension {
883                lower_bound: 1,
884                length: 1,
885            }],
886            vec![string].into_iter().map(Datum::String),
887        )?;
888
889        Ok(temp_storage.push_unary_row(row))
890    } else {
891        string_to_array_impl(string, delimiter, null_string, temp_storage)
892    }
893}
894
895fn string_to_array_impl<'a>(
896    string: &str,
897    delimiter: &str,
898    null_string: Datum<'a>,
899    temp_storage: &'a RowArena,
900) -> Result<Datum<'a>, EvalError> {
901    let mut row = Row::default();
902    let mut packer = row.packer();
903
904    let result = string.split(delimiter);
905    let found: Vec<&str> = if delimiter.is_empty() {
906        result.filter(|s| !s.is_empty()).collect()
907    } else {
908        result.collect()
909    };
910    let array_dimensions = [ArrayDimension {
911        lower_bound: 1,
912        length: found.len(),
913    }];
914
915    if null_string.is_null() {
916        packer.try_push_array(&array_dimensions, found.into_iter().map(Datum::String))?;
917    } else {
918        let null_string = null_string.unwrap_str();
919        let found_datums = found.into_iter().map(|chunk| {
920            if chunk.eq(null_string) {
921                Datum::Null
922            } else {
923                Datum::String(chunk)
924            }
925        });
926
927        packer.try_push_array(&array_dimensions, found_datums)?;
928    }
929
930    Ok(temp_storage.push_unary_row(row))
931}
932
933fn substr<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
934    let s: &'a str = datums[0].unwrap_str();
935
936    let raw_start_idx = i64::from(datums[1].unwrap_int32()) - 1;
937    let start_idx = match usize::try_from(cmp::max(raw_start_idx, 0)) {
938        Ok(i) => i,
939        Err(_) => {
940            return Err(EvalError::InvalidParameterValue(
941                format!(
942                    "substring starting index ({}) exceeds min/max position",
943                    raw_start_idx
944                )
945                .into(),
946            ));
947        }
948    };
949
950    let mut char_indices = s.char_indices();
951    let get_str_index = |(index, _char)| index;
952
953    let str_len = s.len();
954    let start_char_idx = char_indices.nth(start_idx).map_or(str_len, get_str_index);
955
956    if datums.len() == 3 {
957        let end_idx = match i64::from(datums[2].unwrap_int32()) {
958            e if e < 0 => {
959                return Err(EvalError::InvalidParameterValue(
960                    "negative substring length not allowed".into(),
961                ));
962            }
963            e if e == 0 || e + raw_start_idx < 1 => return Ok(Datum::String("")),
964            e => {
965                let e = cmp::min(raw_start_idx + e - 1, e - 1);
966                match usize::try_from(e) {
967                    Ok(i) => i,
968                    Err(_) => {
969                        return Err(EvalError::InvalidParameterValue(
970                            format!("substring length ({}) exceeds max position", e).into(),
971                        ));
972                    }
973                }
974            }
975        };
976
977        let end_char_idx = char_indices.nth(end_idx).map_or(str_len, get_str_index);
978
979        Ok(Datum::String(&s[start_char_idx..end_char_idx]))
980    } else {
981        Ok(Datum::String(&s[start_char_idx..]))
982    }
983}
984
985fn split_part<'a>(datums: &[Datum<'a>]) -> Result<Datum<'a>, EvalError> {
986    let string = datums[0].unwrap_str();
987    let delimiter = datums[1].unwrap_str();
988
989    // Provided index value begins at 1, not 0.
990    let index = match usize::try_from(i64::from(datums[2].unwrap_int32()) - 1) {
991        Ok(index) => index,
992        Err(_) => {
993            return Err(EvalError::InvalidParameterValue(
994                "field position must be greater than zero".into(),
995            ));
996        }
997    };
998
999    // If the provided delimiter is the empty string,
1000    // PostgreSQL does not break the string into individual
1001    // characters. Instead, it generates the following parts: [string].
1002    if delimiter.is_empty() {
1003        if index == 0 {
1004            return Ok(datums[0]);
1005        } else {
1006            return Ok(Datum::String(""));
1007        }
1008    }
1009
1010    // If provided index is greater than the number of split parts,
1011    // return an empty string.
1012    Ok(Datum::String(
1013        string.split(delimiter).nth(index).unwrap_or(""),
1014    ))
1015}
1016
1017fn text_concat_variadic<'a>(
1018    datums: &[Datum<'a>],
1019    temp_storage: &'a RowArena,
1020) -> Result<Datum<'a>, EvalError> {
1021    let mut total_size = 0;
1022    for d in datums {
1023        if !d.is_null() {
1024            total_size += d.unwrap_str().len();
1025            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1026                return Err(EvalError::LengthTooLarge);
1027            }
1028        }
1029    }
1030    let mut buf = String::new();
1031    for d in datums {
1032        if !d.is_null() {
1033            buf.push_str(d.unwrap_str());
1034        }
1035    }
1036    Ok(Datum::String(temp_storage.push_string(buf)))
1037}
1038
1039fn text_concat_ws<'a>(
1040    datums: &[Datum<'a>],
1041    temp_storage: &'a RowArena,
1042) -> Result<Datum<'a>, EvalError> {
1043    let ws = match datums[0] {
1044        Datum::Null => return Ok(Datum::Null),
1045        d => d.unwrap_str(),
1046    };
1047
1048    let mut total_size = 0;
1049    for d in &datums[1..] {
1050        if !d.is_null() {
1051            total_size += d.unwrap_str().len();
1052            total_size += ws.len();
1053            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1054                return Err(EvalError::LengthTooLarge);
1055            }
1056        }
1057    }
1058
1059    let buf = Itertools::join(
1060        &mut datums[1..].iter().filter_map(|d| match d {
1061            Datum::Null => None,
1062            d => Some(d.unwrap_str()),
1063        }),
1064        ws,
1065    );
1066
1067    Ok(Datum::String(temp_storage.push_string(buf)))
1068}
1069
1070fn translate<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
1071    let string = datums[0].unwrap_str();
1072    let from = datums[1].unwrap_str().chars().collect::<Vec<_>>();
1073    let to = datums[2].unwrap_str().chars().collect::<Vec<_>>();
1074
1075    Datum::String(
1076        temp_storage.push_string(
1077            string
1078                .chars()
1079                .filter_map(|c| match from.iter().position(|f| f == &c) {
1080                    Some(idx) => to.get(idx).copied(),
1081                    None => Some(c),
1082                })
1083                .collect(),
1084        ),
1085    )
1086}
1087
1088// TODO ///
1089
1090// TODO(benesch): remove potentially dangerous usage of `as`.
1091#[allow(clippy::as_conversions)]
1092fn list_slice_linear<'a>(datums: &[Datum<'a>], temp_storage: &'a RowArena) -> Datum<'a> {
1093    assert_eq!(
1094        datums.len() % 2,
1095        1,
1096        "expr::scalar::func::list_slice expects an odd number of arguments; 1 for list + 2 \
1097        for each start-end pair"
1098    );
1099    assert!(
1100        datums.len() > 2,
1101        "expr::scalar::func::list_slice expects at least 3 arguments; 1 for list + at least \
1102        one start-end pair"
1103    );
1104
1105    let mut start_idx = 0;
1106    let mut total_length = usize::MAX;
1107
1108    for (start, end) in datums[1..].iter().tuples::<(_, _)>() {
1109        let start = std::cmp::max(start.unwrap_int64(), 1);
1110        let end = end.unwrap_int64();
1111
1112        // Result should be empty list.
1113        if start > end {
1114            start_idx = 0;
1115            total_length = 0;
1116            break;
1117        }
1118
1119        let start_inner = start as usize - 1;
1120        // Start index only moves to geq positions.
1121        start_idx += start_inner;
1122
1123        // Length index only moves to leq positions
1124        let length_inner = (end - start) as usize + 1;
1125        total_length = std::cmp::min(length_inner, total_length - start_inner);
1126    }
1127
1128    let iter = datums[0]
1129        .unwrap_list()
1130        .iter()
1131        .skip(start_idx)
1132        .take(total_length);
1133
1134    temp_storage.make_datum(|row| {
1135        row.push_list_with(|row| {
1136            // if iter is empty, will get the appropriate empty list.
1137            for d in iter {
1138                row.push(d);
1139            }
1140        });
1141    })
1142}
1143
/// A scalar function that takes a variable number of arguments.
///
/// Evaluation is dispatched by [`VariadicFunc::eval`], and result types are
/// computed by [`VariadicFunc::output_type`]. Several variants carry type
/// information that cannot be recovered from the arguments alone (for
/// example, the element type of an empty array or list).
///
/// NOTE(review): the derived `Ord`/`PartialOrd` follow variant declaration
/// order, so reordering variants changes comparison results; new variants
/// should presumably be appended — confirm whether anything persists that order.
#[derive(
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum VariadicFunc {
    Coalesce,
    Greatest,
    Least,
    Concat,
    ConcatWs,
    MakeTimestamp,
    PadLeading,
    Substr,
    Replace,
    JsonbBuildArray,
    JsonbBuildObject,
    MapBuild {
        // Value type of the resulting map (see `output_type`).
        value_type: SqlScalarType,
    },
    ArrayCreate {
        // We need to know the element type to type empty arrays.
        // When it is itself an array type, `eval` builds a multidimensional
        // array via `array_create_multidim`.
        elem_type: SqlScalarType,
    },
    ArrayToString {
        // Element type of the input array; passed to `array_to_string`.
        elem_type: SqlScalarType,
    },
    ArrayIndex {
        // Adjusts the index by offset depending on whether being called on an array or an
        // Int2Vector.
        offset: i64,
    },
    ListCreate {
        // We need to know the element type to type empty lists.
        elem_type: SqlScalarType,
    },
    RecordCreate {
        // Names of the record's fields, zipped one-to-one with the argument
        // types in `output_type`.
        field_names: Vec<ColumnName>,
    },
    ListIndex,
    ListSliceLinear,
    SplitPart,
    RegexpMatch,
    HmacString,
    HmacBytes,
    ErrorIfNull,
    DateBinTimestamp,
    DateBinTimestampTz,
    DateDiffTimestamp,
    DateDiffTimestampTz,
    DateDiffDate,
    DateDiffTime,
    And,
    Or,
    RangeCreate {
        // Element type of the resulting range (see `output_type`).
        elem_type: SqlScalarType,
    },
    MakeAclItem,
    MakeMzAclItem,
    Translate,
    ArrayPosition,
    ArrayFill {
        // Element type of the resulting array (see `output_type`).
        elem_type: SqlScalarType,
    },
    StringToArray,
    TimezoneTime,
    RegexpSplitToArray,
    RegexpReplace,
}
1220
1221impl VariadicFunc {
1222    pub fn eval<'a>(
1223        &'a self,
1224        datums: &[Datum<'a>],
1225        temp_storage: &'a RowArena,
1226        exprs: &'a [MirScalarExpr],
1227    ) -> Result<Datum<'a>, EvalError> {
1228        // Evaluate all non-eager functions directly
1229        match self {
1230            VariadicFunc::Coalesce => return coalesce(datums, temp_storage, exprs),
1231            VariadicFunc::Greatest => return greatest(datums, temp_storage, exprs),
1232            VariadicFunc::And => return and(datums, temp_storage, exprs),
1233            VariadicFunc::Or => return or(datums, temp_storage, exprs),
1234            VariadicFunc::ErrorIfNull => return error_if_null(datums, temp_storage, exprs),
1235            VariadicFunc::Least => return least(datums, temp_storage, exprs),
1236            _ => {}
1237        };
1238
1239        // Compute parameters to eager functions
1240        let ds = exprs
1241            .iter()
1242            .map(|e| e.eval(datums, temp_storage))
1243            .collect::<Result<Vec<_>, _>>()?;
1244        // Check NULL propagation
1245        if self.propagates_nulls() && ds.iter().any(|d| d.is_null()) {
1246            return Ok(Datum::Null);
1247        }
1248
1249        // Evaluate eager functions
1250        match self {
1251            VariadicFunc::Coalesce
1252            | VariadicFunc::Greatest
1253            | VariadicFunc::And
1254            | VariadicFunc::Or
1255            | VariadicFunc::ErrorIfNull
1256            | VariadicFunc::Least => unreachable!(),
1257            VariadicFunc::Concat => text_concat_variadic(&ds, temp_storage),
1258            VariadicFunc::ConcatWs => text_concat_ws(&ds, temp_storage),
1259            VariadicFunc::MakeTimestamp => make_timestamp(&ds),
1260            VariadicFunc::PadLeading => pad_leading(&ds, temp_storage),
1261            VariadicFunc::Substr => substr(&ds),
1262            VariadicFunc::Replace => replace(&ds, temp_storage),
1263            VariadicFunc::Translate => Ok(translate(&ds, temp_storage)),
1264            VariadicFunc::JsonbBuildArray => Ok(jsonb_build_array(&ds, temp_storage)),
1265            VariadicFunc::JsonbBuildObject => jsonb_build_object(&ds, temp_storage),
1266            VariadicFunc::MapBuild { .. } => Ok(map_build(&ds, temp_storage)),
1267            VariadicFunc::ArrayCreate {
1268                elem_type: SqlScalarType::Array(_),
1269            } => array_create_multidim(&ds, temp_storage),
1270            VariadicFunc::ArrayCreate { .. } => array_create_scalar(&ds, temp_storage),
1271            VariadicFunc::ArrayToString { elem_type } => {
1272                array_to_string(&ds, elem_type, temp_storage)
1273            }
1274            VariadicFunc::ArrayIndex { offset } => Ok(array_index(&ds, *offset)),
1275
1276            VariadicFunc::ListCreate { .. } | VariadicFunc::RecordCreate { .. } => {
1277                Ok(list_create(&ds, temp_storage))
1278            }
1279            VariadicFunc::ListIndex => Ok(list_index(&ds)),
1280            VariadicFunc::ListSliceLinear => Ok(list_slice_linear(&ds, temp_storage)),
1281            VariadicFunc::SplitPart => split_part(&ds),
1282            VariadicFunc::RegexpMatch => regexp_match_dynamic(&ds, temp_storage),
1283            VariadicFunc::HmacString => hmac_string(&ds, temp_storage),
1284            VariadicFunc::HmacBytes => hmac_bytes(&ds, temp_storage),
1285            VariadicFunc::DateBinTimestamp => date_bin(
1286                ds[0].unwrap_interval(),
1287                ds[1].unwrap_timestamp(),
1288                ds[2].unwrap_timestamp(),
1289            )
1290            .into_result(temp_storage),
1291            VariadicFunc::DateBinTimestampTz => date_bin(
1292                ds[0].unwrap_interval(),
1293                ds[1].unwrap_timestamptz(),
1294                ds[2].unwrap_timestamptz(),
1295            )
1296            .into_result(temp_storage),
1297            VariadicFunc::DateDiffTimestamp => date_diff_timestamp(ds[0], ds[1], ds[2]),
1298            VariadicFunc::DateDiffTimestampTz => date_diff_timestamptz(ds[0], ds[1], ds[2]),
1299            VariadicFunc::DateDiffDate => date_diff_date(ds[0], ds[1], ds[2]),
1300            VariadicFunc::DateDiffTime => date_diff_time(ds[0], ds[1], ds[2]),
1301            VariadicFunc::RangeCreate { .. } => create_range(&ds, temp_storage),
1302            VariadicFunc::MakeAclItem => make_acl_item(&ds),
1303            VariadicFunc::MakeMzAclItem => make_mz_acl_item(&ds),
1304            VariadicFunc::ArrayPosition => array_position(&ds),
1305            VariadicFunc::ArrayFill { .. } => array_fill(&ds, temp_storage),
1306            VariadicFunc::TimezoneTime => parse_timezone(ds[0].unwrap_str(), TimezoneSpec::Posix)
1307                .map(|tz| {
1308                    timezone_time(
1309                        tz,
1310                        ds[1].unwrap_time(),
1311                        &ds[2].unwrap_timestamptz().naive_utc(),
1312                    )
1313                    .into()
1314                }),
1315            VariadicFunc::RegexpSplitToArray => {
1316                let flags = if ds.len() == 2 {
1317                    Datum::String("")
1318                } else {
1319                    ds[2]
1320                };
1321                regexp_split_to_array(ds[0], ds[1], flags, temp_storage)
1322            }
1323            VariadicFunc::RegexpReplace => regexp_replace_dynamic(&ds, temp_storage),
1324            VariadicFunc::StringToArray => {
1325                let null_string = if ds.len() == 2 { Datum::Null } else { ds[2] };
1326
1327                string_to_array(ds[0], ds[1], null_string, temp_storage)
1328            }
1329        }
1330    }
1331
1332    pub fn is_associative(&self) -> bool {
1333        match self {
1334            VariadicFunc::Coalesce
1335            | VariadicFunc::Greatest
1336            | VariadicFunc::Least
1337            | VariadicFunc::Concat
1338            | VariadicFunc::And
1339            | VariadicFunc::Or => true,
1340
1341            VariadicFunc::MakeTimestamp
1342            | VariadicFunc::PadLeading
1343            | VariadicFunc::ConcatWs
1344            | VariadicFunc::Substr
1345            | VariadicFunc::Replace
1346            | VariadicFunc::Translate
1347            | VariadicFunc::JsonbBuildArray
1348            | VariadicFunc::JsonbBuildObject
1349            | VariadicFunc::MapBuild { value_type: _ }
1350            | VariadicFunc::ArrayCreate { elem_type: _ }
1351            | VariadicFunc::ArrayToString { elem_type: _ }
1352            | VariadicFunc::ArrayIndex { offset: _ }
1353            | VariadicFunc::ListCreate { elem_type: _ }
1354            | VariadicFunc::RecordCreate { field_names: _ }
1355            | VariadicFunc::ListIndex
1356            | VariadicFunc::ListSliceLinear
1357            | VariadicFunc::SplitPart
1358            | VariadicFunc::RegexpMatch
1359            | VariadicFunc::HmacString
1360            | VariadicFunc::HmacBytes
1361            | VariadicFunc::ErrorIfNull
1362            | VariadicFunc::DateBinTimestamp
1363            | VariadicFunc::DateBinTimestampTz
1364            | VariadicFunc::DateDiffTimestamp
1365            | VariadicFunc::DateDiffTimestampTz
1366            | VariadicFunc::DateDiffDate
1367            | VariadicFunc::DateDiffTime
1368            | VariadicFunc::RangeCreate { .. }
1369            | VariadicFunc::MakeAclItem
1370            | VariadicFunc::MakeMzAclItem
1371            | VariadicFunc::ArrayPosition
1372            | VariadicFunc::ArrayFill { .. }
1373            | VariadicFunc::TimezoneTime
1374            | VariadicFunc::RegexpSplitToArray
1375            | VariadicFunc::StringToArray
1376            | VariadicFunc::RegexpReplace => false,
1377        }
1378    }
1379
1380    pub fn output_type(&self, input_types: Vec<SqlColumnType>) -> SqlColumnType {
1381        use VariadicFunc::*;
1382        let in_nullable = input_types.iter().any(|t| t.nullable);
1383        match self {
1384            Greatest | Least => input_types.into_iter().reduce(|l, r| l.union(&r)).unwrap(),
1385            Coalesce => {
1386                // Note that the parser doesn't allow empty argument lists for variadic functions
1387                // that use the standard function call syntax (ArrayCreate and co. are different
1388                // because of the special syntax for calling them).
1389                let nullable = input_types.iter().all(|typ| typ.nullable);
1390                input_types
1391                    .into_iter()
1392                    .reduce(|l, r| l.union(&r))
1393                    .unwrap()
1394                    .nullable(nullable)
1395            }
1396            Concat | ConcatWs => SqlScalarType::String.nullable(in_nullable),
1397            MakeTimestamp => SqlScalarType::Timestamp { precision: None }.nullable(true),
1398            PadLeading => SqlScalarType::String.nullable(in_nullable),
1399            Substr => SqlScalarType::String.nullable(in_nullable),
1400            Replace => SqlScalarType::String.nullable(in_nullable),
1401            Translate => SqlScalarType::String.nullable(in_nullable),
1402            JsonbBuildArray | JsonbBuildObject => SqlScalarType::Jsonb.nullable(true),
1403            MapBuild { value_type } => SqlScalarType::Map {
1404                value_type: Box::new(value_type.clone()),
1405                custom_id: None,
1406            }
1407            .nullable(true),
1408            ArrayCreate { elem_type } => {
1409                soft_assert_or_log!(
1410                    input_types.iter().all(|t| {
1411                        // This ensures that the types are compatiable, but nullability may vary deeply in the types.
1412                        ReprScalarType::from(elem_type)
1413                            .union(&ReprScalarType::from(&t.scalar_type))
1414                            .is_ok()
1415                    }),
1416                    "Args to ArrayCreate should have types that are repr-compatible with the elem_type.\nArgs:{input_types:#?}\nelem_type:{elem_type:#?}"
1417                );
1418                match elem_type {
1419                    SqlScalarType::Array(_) => elem_type.clone().nullable(false),
1420                    _ => SqlScalarType::Array(Box::new(elem_type.clone())).nullable(false),
1421                }
1422            }
1423            ArrayToString { .. } => SqlScalarType::String.nullable(in_nullable),
1424            ArrayIndex { .. } => input_types[0]
1425                .scalar_type
1426                .unwrap_array_element_type()
1427                .clone()
1428                .nullable(true),
1429            ListCreate { elem_type } => {
1430                soft_assert_or_log!(
1431                    input_types.iter().all(|t| {
1432                        // This ensures that the types are compatiable, but nullability may vary deeply in the types.
1433                        ReprScalarType::from(elem_type)
1434                            .union(&ReprScalarType::from(&t.scalar_type))
1435                            .is_ok()
1436                    }),
1437                    "Args to ListCreate should have types that are compatible with the elem_type.\nArgs:{input_types:#?}\nelem_type:{elem_type:#?}"
1438                );
1439                SqlScalarType::List {
1440                    element_type: Box::new(elem_type.clone()),
1441                    custom_id: None,
1442                }
1443                .nullable(false)
1444            }
1445            ListIndex => input_types[0]
1446                .scalar_type
1447                .unwrap_list_nth_layer_type(input_types.len() - 1)
1448                .clone()
1449                .nullable(true),
1450            ListSliceLinear { .. } => input_types[0].scalar_type.clone().nullable(in_nullable),
1451            RecordCreate { field_names } => SqlScalarType::Record {
1452                fields: field_names
1453                    .clone()
1454                    .into_iter()
1455                    .zip_eq(input_types)
1456                    .collect(),
1457                custom_id: None,
1458            }
1459            .nullable(false),
1460            SplitPart => SqlScalarType::String.nullable(in_nullable),
1461            RegexpMatch => SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true),
1462            HmacString | HmacBytes => SqlScalarType::Bytes.nullable(in_nullable),
1463            ErrorIfNull => input_types[0].scalar_type.clone().nullable(false),
1464            DateBinTimestamp => SqlScalarType::Timestamp { precision: None }.nullable(in_nullable),
1465            DateBinTimestampTz => {
1466                SqlScalarType::TimestampTz { precision: None }.nullable(in_nullable)
1467            }
1468            DateDiffTimestamp => SqlScalarType::Int64.nullable(in_nullable),
1469            DateDiffTimestampTz => SqlScalarType::Int64.nullable(in_nullable),
1470            DateDiffDate => SqlScalarType::Int64.nullable(in_nullable),
1471            DateDiffTime => SqlScalarType::Int64.nullable(in_nullable),
1472            And | Or => SqlScalarType::Bool.nullable(in_nullable),
1473            RangeCreate { elem_type } => SqlScalarType::Range {
1474                element_type: Box::new(elem_type.clone()),
1475            }
1476            .nullable(false),
1477            MakeAclItem => SqlScalarType::AclItem.nullable(true),
1478            MakeMzAclItem => SqlScalarType::MzAclItem.nullable(true),
1479            ArrayPosition => SqlScalarType::Int32.nullable(true),
1480            ArrayFill { elem_type } => {
1481                SqlScalarType::Array(Box::new(elem_type.clone())).nullable(false)
1482            }
1483            TimezoneTime => SqlScalarType::Time.nullable(in_nullable),
1484            RegexpSplitToArray => {
1485                SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(in_nullable)
1486            }
1487            RegexpReplace => SqlScalarType::String.nullable(in_nullable),
1488            StringToArray => SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true),
1489        }
1490    }
1491
1492    /// Whether the function output is NULL if any of its inputs are NULL.
1493    ///
1494    /// NB: if any input is NULL the output will be returned as NULL without
1495    /// calling the function.
1496    pub fn propagates_nulls(&self) -> bool {
1497        // NOTE: The following is a list of the variadic functions
1498        // that **DO NOT** propagate nulls.
1499        !matches!(
1500            self,
1501            VariadicFunc::And
1502                | VariadicFunc::Or
1503                | VariadicFunc::Coalesce
1504                | VariadicFunc::Greatest
1505                | VariadicFunc::Least
1506                | VariadicFunc::Concat
1507                | VariadicFunc::ConcatWs
1508                | VariadicFunc::JsonbBuildArray
1509                | VariadicFunc::JsonbBuildObject
1510                | VariadicFunc::MapBuild { .. }
1511                | VariadicFunc::ListCreate { .. }
1512                | VariadicFunc::RecordCreate { .. }
1513                | VariadicFunc::ArrayCreate { .. }
1514                | VariadicFunc::ArrayToString { .. }
1515                | VariadicFunc::ErrorIfNull
1516                | VariadicFunc::RangeCreate { .. }
1517                | VariadicFunc::ArrayPosition
1518                | VariadicFunc::ArrayFill { .. }
1519                | VariadicFunc::StringToArray
1520        )
1521    }
1522
1523    /// Whether the function might return NULL even if none of its inputs are
1524    /// NULL.
1525    ///
1526    /// This is presently conservative, and may indicate that a function
1527    /// introduces nulls even when it does not.
1528    pub fn introduces_nulls(&self) -> bool {
1529        use VariadicFunc::*;
1530        match self {
1531            Concat
1532            | ConcatWs
1533            | PadLeading
1534            | Substr
1535            | Replace
1536            | Translate
1537            | JsonbBuildArray
1538            | JsonbBuildObject
1539            | MapBuild { .. }
1540            | ArrayCreate { .. }
1541            | ArrayToString { .. }
1542            | ListCreate { .. }
1543            | RecordCreate { .. }
1544            | ListSliceLinear
1545            | SplitPart
1546            | HmacString
1547            | HmacBytes
1548            | ErrorIfNull
1549            | DateBinTimestamp
1550            | DateBinTimestampTz
1551            | DateDiffTimestamp
1552            | DateDiffTimestampTz
1553            | DateDiffDate
1554            | DateDiffTime
1555            | RangeCreate { .. }
1556            | And
1557            | Or
1558            | MakeAclItem
1559            | MakeMzAclItem
1560            | ArrayPosition
1561            | ArrayFill { .. }
1562            | TimezoneTime
1563            | RegexpSplitToArray
1564            | RegexpReplace => false,
1565            Coalesce
1566            | Greatest
1567            | Least
1568            | MakeTimestamp
1569            | ArrayIndex { .. }
1570            | StringToArray
1571            | ListIndex
1572            | RegexpMatch => true,
1573        }
1574    }
1575
1576    pub fn switch_and_or(&self) -> Self {
1577        match self {
1578            VariadicFunc::And => VariadicFunc::Or,
1579            VariadicFunc::Or => VariadicFunc::And,
1580            _ => unreachable!(),
1581        }
1582    }
1583
1584    pub fn is_infix_op(&self) -> bool {
1585        use VariadicFunc::*;
1586        matches!(self, And | Or)
1587    }
1588
1589    /// Gives the unit (u) of OR or AND, such that `u AND/OR x == x`.
1590    /// Note that a 0-arg AND/OR evaluates to unit_of_and_or.
1591    pub fn unit_of_and_or(&self) -> MirScalarExpr {
1592        match self {
1593            VariadicFunc::And => MirScalarExpr::literal_true(),
1594            VariadicFunc::Or => MirScalarExpr::literal_false(),
1595            _ => unreachable!(),
1596        }
1597    }
1598
1599    /// Gives the zero (z) of OR or AND, such that `z AND/OR x == z`.
1600    pub fn zero_of_and_or(&self) -> MirScalarExpr {
1601        match self {
1602            VariadicFunc::And => MirScalarExpr::literal_false(),
1603            VariadicFunc::Or => MirScalarExpr::literal_true(),
1604            _ => unreachable!(),
1605        }
1606    }
1607
1608    /// Returns true if the function could introduce an error on non-error inputs.
1609    pub fn could_error(&self) -> bool {
1610        match self {
1611            VariadicFunc::And | VariadicFunc::Or => false,
1612            VariadicFunc::Coalesce => false,
1613            VariadicFunc::Greatest | VariadicFunc::Least => false,
1614            VariadicFunc::Concat | VariadicFunc::ConcatWs => false,
1615            VariadicFunc::Replace => false,
1616            VariadicFunc::Translate => false,
1617            VariadicFunc::ArrayIndex { .. } => false,
1618            VariadicFunc::ListCreate { .. } | VariadicFunc::RecordCreate { .. } => false,
1619            // All other cases are unknown
1620            _ => true,
1621        }
1622    }
1623
1624    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
1625    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
1626    /// determine the range of possible outputs just by mapping the endpoints.
1627    ///
1628    /// This describes the *pointwise* behaviour of the function:
1629    /// ie. if more than one argument is provided, this describes the behaviour of
1630    /// any specific argument as the others are held constant. (For example, `COALESCE(a, b)` is
1631    /// monotone in `a` because for any particular value of `b`, increasing `a` will never
1632    /// cause the result to decrease.)
1633    ///
1634    /// This property describes the behaviour of the function over ranges where the function is defined:
1635    /// ie. the arguments and the result are non-error datums.
1636    pub fn is_monotone(&self) -> bool {
1637        match self {
1638            VariadicFunc::Coalesce
1639            | VariadicFunc::Greatest
1640            | VariadicFunc::Least
1641            | VariadicFunc::And
1642            | VariadicFunc::Or => true,
1643            VariadicFunc::Concat
1644            | VariadicFunc::ConcatWs
1645            | VariadicFunc::MakeTimestamp
1646            | VariadicFunc::PadLeading
1647            | VariadicFunc::Substr
1648            | VariadicFunc::Replace
1649            | VariadicFunc::JsonbBuildArray
1650            | VariadicFunc::JsonbBuildObject
1651            | VariadicFunc::MapBuild { .. }
1652            | VariadicFunc::ArrayCreate { .. }
1653            | VariadicFunc::ArrayToString { .. }
1654            | VariadicFunc::ArrayIndex { .. }
1655            | VariadicFunc::ListCreate { .. }
1656            | VariadicFunc::RecordCreate { .. }
1657            | VariadicFunc::ListIndex
1658            | VariadicFunc::ListSliceLinear
1659            | VariadicFunc::SplitPart
1660            | VariadicFunc::RegexpMatch
1661            | VariadicFunc::HmacString
1662            | VariadicFunc::HmacBytes
1663            | VariadicFunc::ErrorIfNull
1664            | VariadicFunc::DateBinTimestamp
1665            | VariadicFunc::DateBinTimestampTz
1666            | VariadicFunc::RangeCreate { .. }
1667            | VariadicFunc::MakeAclItem
1668            | VariadicFunc::MakeMzAclItem
1669            | VariadicFunc::Translate
1670            | VariadicFunc::ArrayPosition
1671            | VariadicFunc::ArrayFill { .. }
1672            | VariadicFunc::DateDiffTimestamp
1673            | VariadicFunc::DateDiffTimestampTz
1674            | VariadicFunc::DateDiffDate
1675            | VariadicFunc::DateDiffTime
1676            | VariadicFunc::TimezoneTime
1677            | VariadicFunc::RegexpSplitToArray
1678            | VariadicFunc::StringToArray
1679            | VariadicFunc::RegexpReplace => false,
1680        }
1681    }
1682}
1683
1684impl std::fmt::Display for VariadicFunc {
1685    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1686        match self {
1687            VariadicFunc::Coalesce => f.write_str("coalesce"),
1688            VariadicFunc::Greatest => f.write_str("greatest"),
1689            VariadicFunc::Least => f.write_str("least"),
1690            VariadicFunc::Concat => f.write_str("concat"),
1691            VariadicFunc::ConcatWs => f.write_str("concat_ws"),
1692            VariadicFunc::MakeTimestamp => f.write_str("makets"),
1693            VariadicFunc::PadLeading => f.write_str("lpad"),
1694            VariadicFunc::Substr => f.write_str("substr"),
1695            VariadicFunc::Replace => f.write_str("replace"),
1696            VariadicFunc::Translate => f.write_str("translate"),
1697            VariadicFunc::JsonbBuildArray => f.write_str("jsonb_build_array"),
1698            VariadicFunc::JsonbBuildObject => f.write_str("jsonb_build_object"),
1699            VariadicFunc::MapBuild { .. } => f.write_str("map_build"),
1700            VariadicFunc::ArrayCreate { .. } => f.write_str("array_create"),
1701            VariadicFunc::ArrayToString { .. } => f.write_str("array_to_string"),
1702            VariadicFunc::ArrayIndex { .. } => f.write_str("array_index"),
1703            VariadicFunc::ListCreate { .. } => f.write_str("list_create"),
1704            VariadicFunc::RecordCreate { .. } => f.write_str("record_create"),
1705            VariadicFunc::ListIndex => f.write_str("list_index"),
1706            VariadicFunc::ListSliceLinear => f.write_str("list_slice_linear"),
1707            VariadicFunc::SplitPart => f.write_str("split_string"),
1708            VariadicFunc::RegexpMatch => f.write_str("regexp_match"),
1709            VariadicFunc::HmacString | VariadicFunc::HmacBytes => f.write_str("hmac"),
1710            VariadicFunc::ErrorIfNull => f.write_str("error_if_null"),
1711            VariadicFunc::DateBinTimestamp => f.write_str("timestamp_bin"),
1712            VariadicFunc::DateBinTimestampTz => f.write_str("timestamptz_bin"),
1713            VariadicFunc::DateDiffTimestamp
1714            | VariadicFunc::DateDiffTimestampTz
1715            | VariadicFunc::DateDiffDate
1716            | VariadicFunc::DateDiffTime => f.write_str("datediff"),
1717            VariadicFunc::And => f.write_str("AND"),
1718            VariadicFunc::Or => f.write_str("OR"),
1719            VariadicFunc::RangeCreate {
1720                elem_type: element_type,
1721            } => f.write_str(match element_type {
1722                SqlScalarType::Int32 => "int4range",
1723                SqlScalarType::Int64 => "int8range",
1724                SqlScalarType::Date => "daterange",
1725                SqlScalarType::Numeric { .. } => "numrange",
1726                SqlScalarType::Timestamp { .. } => "tsrange",
1727                SqlScalarType::TimestampTz { .. } => "tstzrange",
1728                _ => unreachable!(),
1729            }),
1730            VariadicFunc::MakeAclItem => f.write_str("makeaclitem"),
1731            VariadicFunc::MakeMzAclItem => f.write_str("make_mz_aclitem"),
1732            VariadicFunc::ArrayPosition => f.write_str("array_position"),
1733            VariadicFunc::ArrayFill { .. } => f.write_str("array_fill"),
1734            VariadicFunc::TimezoneTime => f.write_str("timezonet"),
1735            VariadicFunc::RegexpSplitToArray => f.write_str("regexp_split_to_array"),
1736            VariadicFunc::RegexpReplace => f.write_str("regexp_replace"),
1737            VariadicFunc::StringToArray => f.write_str("string_to_array"),
1738        }
1739    }
1740}