Skip to main content

mz_expr/scalar/func/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Variadic functions.
15
16use std::borrow::Cow;
17use std::cmp;
18use std::fmt;
19
20use aws_lc_rs::hmac as aws_hmac;
21use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
22use fallible_iterator::FallibleIterator;
23use hmac::{Hmac, Mac};
24use itertools::Itertools;
25use md5::Md5;
26use mz_expr_derive::sqlfunc;
27use mz_lowertest::MzReflect;
28use mz_ore::cast::{CastFrom, ReinterpretCast};
29use mz_pgtz::timezone::TimezoneSpec;
30use mz_repr::ReprColumnType;
31use mz_repr::adt::array::{Array, ArrayDimension, ArrayDimensions, InvalidArrayError};
32use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem};
33use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags};
34use mz_repr::adt::system::Oid;
35use mz_repr::adt::timestamp::CheckedTimestamp;
36use mz_repr::role_id::RoleId;
37use mz_repr::{
38    ColumnName, Datum, DatumList, FromDatum, InputDatumType, OptionalArg, OutputDatumType, Row,
39    RowArena, SqlColumnType, SqlScalarType, Variadic,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::func::CaseLiteral;
44use crate::func::{
45    MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, parse_timezone,
46    regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re, stringify_datum,
47    timezone_time,
48};
49use crate::{EvalError, MirScalarExpr};
50use mz_repr::adt::date::Date;
51use mz_repr::adt::interval::Interval;
52use mz_repr::adt::jsonb::JsonbRef;
53
54#[derive(
55    Ord,
56    PartialOrd,
57    Clone,
58    Debug,
59    Eq,
60    PartialEq,
61    Serialize,
62    Deserialize,
63    Hash,
64    MzReflect
65)]
66pub struct And;
67
68impl fmt::Display for And {
69    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70        f.write_str("AND")
71    }
72}
73
74impl LazyVariadicFunc for And {
75    fn eval<'a>(
76        &'a self,
77        datums: &[Datum<'a>],
78        temp_storage: &'a RowArena,
79        exprs: &'a [MirScalarExpr],
80    ) -> Result<Datum<'a>, EvalError> {
81        // If any is false, then return false. Else, if any is null, then return null. Else, return true.
82        let mut null = false;
83        let mut err = None;
84        for expr in exprs {
85            match expr.eval(datums, temp_storage) {
86                Ok(Datum::False) => return Ok(Datum::False), // short-circuit
87                Ok(Datum::True) => {}
88                // No return in these two cases, because we might still see a false
89                Ok(Datum::Null) => null = true,
90                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
91                _ => unreachable!(),
92            }
93        }
94        match (err, null) {
95            (Some(err), _) => Err(err),
96            (None, true) => Ok(Datum::Null),
97            (None, false) => Ok(Datum::True),
98        }
99    }
100
101    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
102        let in_nullable = input_types.iter().any(|t| t.nullable);
103        SqlScalarType::Bool.nullable(in_nullable)
104    }
105
106    fn propagates_nulls(&self) -> bool {
107        false
108    }
109
110    fn introduces_nulls(&self) -> bool {
111        false
112    }
113
114    fn could_error(&self) -> bool {
115        false
116    }
117
118    fn is_monotone(&self) -> bool {
119        true
120    }
121
122    fn is_associative(&self) -> bool {
123        true
124    }
125
126    fn is_infix_op(&self) -> bool {
127        true
128    }
129}
130
131#[derive(
132    Ord,
133    PartialOrd,
134    Clone,
135    Debug,
136    Eq,
137    PartialEq,
138    Serialize,
139    Deserialize,
140    Hash,
141    MzReflect
142)]
143pub struct ArrayCreate {
144    pub elem_type: SqlScalarType,
145}
146
147/// Constructs a new multidimensional array out of an arbitrary number of
148/// lower-dimensional arrays.
149///
150/// For example, if given three 1D arrays of length 2, this function will
151/// construct a 2D array with dimensions 3x2.
152///
153/// The input datums in `datums` must all be arrays of the same dimensions.
154/// (The arrays must also be of the same element type, but that is checked by
155/// the SQL type system, rather than checked here at runtime.)
156///
157/// If all input arrays are zero-dimensional arrays, then the output is a zero-
158/// dimensional array. Otherwise, the lower bound of the additional dimension is
159/// one and the length of the new dimension is equal to `datums.len()`.
160///
161/// Null elements are allowed and considered to be zero-dimensional arrays.
162#[sqlfunc(
163    ArrayCreate,
164    output_type_expr = "match &self.elem_type { SqlScalarType::Array(_) => self.elem_type.clone().nullable(false), _ => SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false) }",
165    introduces_nulls = false
166)]
167fn array_create<'a>(
168    &self,
169    datums: Variadic<Datum<'a>>,
170    temp_storage: &'a RowArena,
171) -> Result<Datum<'a>, EvalError> {
172    match &self.elem_type {
173        SqlScalarType::Array(_) => array_create_multidim(&datums, temp_storage),
174        _ => array_create_scalar(&datums, temp_storage),
175    }
176}
177fn array_create_multidim<'a>(
178    datums: &[Datum<'a>],
179    temp_storage: &'a RowArena,
180) -> Result<Datum<'a>, EvalError> {
181    let mut dim: Option<ArrayDimensions> = None;
182    for datum in datums {
183        let actual_dims = match datum {
184            Datum::Null => ArrayDimensions::default(),
185            Datum::Array(arr) => arr.dims(),
186            d => panic!("unexpected datum {d}"),
187        };
188        if let Some(expected) = &dim {
189            if actual_dims.ndims() != expected.ndims() {
190                let actual = actual_dims.ndims().into();
191                let expected = expected.ndims().into();
192                // All input arrays must have the same dimensionality.
193                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
194            }
195            if let Some((e, a)) = expected
196                .into_iter()
197                .zip_eq(actual_dims)
198                .find(|(e, a)| e != a)
199            {
200                let actual = a.length;
201                let expected = e.length;
202                // All input arrays must have the same dimensionality.
203                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
204            }
205        }
206        dim = Some(actual_dims);
207    }
208    // Per PostgreSQL, if all input arrays are zero dimensional, so is the output.
209    if dim.as_ref().map_or(true, ArrayDimensions::is_empty) {
210        return Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&[], &[]))?);
211    }
212
213    let mut dims = vec![ArrayDimension {
214        lower_bound: 1,
215        length: datums.len(),
216    }];
217    if let Some(d) = datums.first() {
218        dims.extend(d.unwrap_array().dims());
219    };
220    let elements = datums
221        .iter()
222        .flat_map(|d| d.unwrap_array().elements().iter());
223    let datum =
224        temp_storage.try_make_datum(move |packer| packer.try_push_array(&dims, elements))?;
225    Ok(datum)
226}
227
228#[derive(
229    Ord,
230    PartialOrd,
231    Clone,
232    Debug,
233    Eq,
234    PartialEq,
235    Serialize,
236    Deserialize,
237    Hash,
238    MzReflect
239)]
240pub struct ArrayFill {
241    pub elem_type: SqlScalarType,
242}
243
244#[sqlfunc(
245    ArrayFill,
246    output_type_expr = "SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false)",
247    introduces_nulls = false
248)]
249fn array_fill<'a>(
250    &self,
251    fill: Datum<'a>,
252    dims: Option<Array<'a>>,
253    lower_bounds: OptionalArg<Option<Array<'a>>>,
254    temp_storage: &'a RowArena,
255) -> Result<Datum<'a>, EvalError> {
256    const MAX_SIZE: usize = 1 << 28 - 1;
257    const NULL_ARR_ERR: &str = "dimension array or low bound array";
258    const NULL_ELEM_ERR: &str = "dimension values";
259
260    if matches!(fill, Datum::Array(_)) {
261        return Err(EvalError::Unsupported {
262            feature: "array_fill with arrays".into(),
263            discussion_no: None,
264        });
265    }
266
267    let Some(arr) = dims else {
268        return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
269    };
270
271    let dimensions = arr
272        .elements()
273        .iter()
274        .map(|d| match d {
275            Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
276            d => Ok(usize::cast_from(u32::reinterpret_cast(d.unwrap_int32()))),
277        })
278        .collect::<Result<Vec<_>, _>>()?;
279
280    let lower_bounds = match *lower_bounds {
281        Some(d) => {
282            let Some(arr) = d else {
283                return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
284            };
285
286            arr.elements()
287                .iter()
288                .map(|l| match l {
289                    Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
290                    l => Ok(isize::cast_from(l.unwrap_int32())),
291                })
292                .collect::<Result<Vec<_>, _>>()?
293        }
294        None => {
295            vec![1isize; dimensions.len()]
296        }
297    };
298
299    if lower_bounds.len() != dimensions.len() {
300        return Err(EvalError::ArrayFillWrongArraySubscripts);
301    }
302
303    let fill_count: usize = dimensions
304        .iter()
305        .cloned()
306        .map(Some)
307        .reduce(|a, b| match (a, b) {
308            (Some(a), Some(b)) => a.checked_mul(b),
309            _ => None,
310        })
311        .flatten()
312        .ok_or(EvalError::MaxArraySizeExceeded(MAX_SIZE))?;
313
314    if matches!(
315        mz_repr::datum_size(&fill).checked_mul(fill_count),
316        None | Some(MAX_SIZE..)
317    ) {
318        return Err(EvalError::MaxArraySizeExceeded(MAX_SIZE));
319    }
320
321    let array_dimensions = if fill_count == 0 {
322        vec![ArrayDimension {
323            lower_bound: 1,
324            length: 0,
325        }]
326    } else {
327        dimensions
328            .into_iter()
329            .zip_eq(lower_bounds)
330            .map(|(length, lower_bound)| ArrayDimension {
331                lower_bound,
332                length,
333            })
334            .collect()
335    };
336
337    Ok(temp_storage.try_make_datum(|packer| {
338        packer.try_push_array(&array_dimensions, vec![fill; fill_count])
339    })?)
340}
341
342#[derive(
343    Ord,
344    PartialOrd,
345    Clone,
346    Debug,
347    Eq,
348    PartialEq,
349    Serialize,
350    Deserialize,
351    Hash,
352    MzReflect
353)]
354pub struct ArrayIndex {
355    pub offset: i64,
356}
357#[sqlfunc(ArrayIndex, sqlname = "array_index", introduces_nulls = true)]
358fn array_index<'a, T: FromDatum<'a>>(
359    &self,
360    array: Array<'a, T>,
361    indices: Variadic<i64>,
362) -> Option<T> {
363    mz_ore::soft_assert_no_log!(
364        self.offset == 0 || self.offset == 1,
365        "offset must be either 0 or 1"
366    );
367
368    let dims = array.dims();
369    if dims.len() != indices.len() {
370        // You missed the datums "layer"
371        return None;
372    }
373
374    let mut final_idx = 0;
375
376    for (d, idx) in dims.into_iter().zip_eq(indices.iter()) {
377        // Lower bound is written in terms of 1-based indexing, which offset accounts for.
378        let idx = isize::cast_from(*idx + self.offset);
379
380        let (lower, upper) = d.dimension_bounds();
381
382        // This index missed all of the data at this layer. The dimension bounds are inclusive,
383        // while range checks are exclusive, so adjust.
384        if !(lower..upper + 1).contains(&idx) {
385            return None;
386        }
387
388        // We discover how many indices our last index represents physically.
389        final_idx *= d.length;
390
391        // Because both index and lower bound are handled in 1-based indexing, taking their
392        // difference moves us back into 0-based indexing. Similarly, if the lower bound is
393        // negative, subtracting a negative value >= to itself ensures its non-negativity.
394        final_idx += usize::try_from(idx - d.lower_bound)
395            .expect("previous bounds check ensures physical index is at least 0");
396    }
397
398    array.elements().typed_iter().nth(final_idx)
399}
400
401#[sqlfunc]
402fn array_position<'a>(
403    array: Array<'a>,
404    search: Datum<'a>,
405    initial_pos: OptionalArg<Option<i32>>,
406) -> Result<Option<i32>, EvalError> {
407    if array.dims().len() > 1 {
408        return Err(EvalError::MultiDimensionalArraySearch);
409    }
410
411    if search == Datum::Null {
412        return Ok(None);
413    }
414
415    let skip = match initial_pos.0 {
416        None => 0,
417        Some(None) => return Err(EvalError::MustNotBeNull("initial position".into())),
418        Some(Some(o)) => usize::try_from(o).unwrap_or(0).saturating_sub(1),
419    };
420
421    let Some(r) = array.elements().iter().skip(skip).position(|d| d == search) else {
422        return Ok(None);
423    };
424
425    // Adjust count for the amount we skipped, plus 1 for adjusting to PG indexing scheme.
426    let p = i32::try_from(r + skip + 1).expect("fewer than i32::MAX elements in array");
427    Ok(Some(p))
428}
429
430#[derive(
431    Ord,
432    PartialOrd,
433    Clone,
434    Debug,
435    Eq,
436    PartialEq,
437    Serialize,
438    Deserialize,
439    Hash,
440    MzReflect
441)]
442pub struct ArrayToString {
443    pub elem_type: SqlScalarType,
444}
445
446#[sqlfunc]
447fn array_to_string<'a>(
448    &self,
449    array: Array<'a>,
450    delimiter: &str,
451    null_str_arg: OptionalArg<Option<&str>>,
452) -> Result<String, EvalError> {
453    // `flatten` treats absent arguments (`None`) the same as explicit NULL
454    // (`Some(None)`), both becoming `None`.
455    let null_str = null_str_arg.flatten();
456    let mut out = String::new();
457    for elem in array.elements().iter() {
458        if elem.is_null() {
459            if let Some(null_str) = null_str {
460                out.push_str(null_str);
461                out.push_str(delimiter);
462            }
463        } else {
464            stringify_datum(&mut out, elem, &self.elem_type)?;
465            out.push_str(delimiter);
466        }
467    }
468    if out.len() > 0 {
469        // Lop off last delimiter only if string is not empty
470        out.truncate(out.len() - delimiter.len());
471    }
472    Ok(out)
473}
474
475#[derive(
476    Ord,
477    PartialOrd,
478    Clone,
479    Debug,
480    Eq,
481    PartialEq,
482    Serialize,
483    Deserialize,
484    Hash,
485    MzReflect
486)]
487pub struct Coalesce;
488
489impl fmt::Display for Coalesce {
490    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
491        f.write_str("coalesce")
492    }
493}
494
495impl LazyVariadicFunc for Coalesce {
496    fn eval<'a>(
497        &'a self,
498        datums: &[Datum<'a>],
499        temp_storage: &'a RowArena,
500        exprs: &'a [MirScalarExpr],
501    ) -> Result<Datum<'a>, EvalError> {
502        for e in exprs {
503            let d = e.eval(datums, temp_storage)?;
504            if !d.is_null() {
505                return Ok(d);
506            }
507        }
508        Ok(Datum::Null)
509    }
510
511    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
512        // Note that the parser doesn't allow empty argument lists for variadic functions
513        // that use the standard function call syntax (ArrayCreate and co. are different
514        // because of the special syntax for calling them).
515        let nullable = input_types.iter().all(|typ| typ.nullable);
516        SqlColumnType::union_many(input_types).nullable(nullable)
517    }
518
519    fn propagates_nulls(&self) -> bool {
520        false
521    }
522
523    fn introduces_nulls(&self) -> bool {
524        false
525    }
526
527    fn could_error(&self) -> bool {
528        false
529    }
530
531    fn is_monotone(&self) -> bool {
532        true
533    }
534
535    fn is_associative(&self) -> bool {
536        true
537    }
538}
539
540#[derive(
541    Ord,
542    PartialOrd,
543    Clone,
544    Debug,
545    Eq,
546    PartialEq,
547    Serialize,
548    Deserialize,
549    Hash,
550    MzReflect
551)]
552pub struct RangeCreate {
553    pub elem_type: SqlScalarType,
554}
555
556impl fmt::Display for RangeCreate {
557    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
558        f.write_str(match &self.elem_type {
559            SqlScalarType::Int32 => "int4range",
560            SqlScalarType::Int64 => "int8range",
561            SqlScalarType::Date => "daterange",
562            SqlScalarType::Numeric { .. } => "numrange",
563            SqlScalarType::Timestamp { .. } => "tsrange",
564            SqlScalarType::TimestampTz { .. } => "tstzrange",
565            _ => unreachable!(),
566        })
567    }
568}
569
570impl EagerVariadicFunc for RangeCreate {
571    type Input<'a> = (Datum<'a>, Datum<'a>, Datum<'a>);
572    type Output<'a> = Result<Datum<'a>, EvalError>;
573
574    fn call<'a>(
575        &self,
576        (lower, upper, flags_datum): Self::Input<'a>,
577        temp_storage: &'a RowArena,
578    ) -> Self::Output<'a> {
579        let flags = match flags_datum {
580            Datum::Null => {
581                return Err(EvalError::InvalidRange(
582                    InvalidRangeError::NullRangeBoundFlags,
583                ));
584            }
585            o => o.unwrap_str(),
586        };
587
588        let (lower_inclusive, upper_inclusive) = parse_range_bound_flags(flags)?;
589
590        let mut range = Range::new(Some((
591            RangeBound::new(lower, lower_inclusive),
592            RangeBound::new(upper, upper_inclusive),
593        )));
594
595        range.canonicalize()?;
596
597        Ok(temp_storage.make_datum(|row| {
598            row.push_range(range).expect("errors already handled");
599        }))
600    }
601
602    fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
603        SqlScalarType::Range {
604            element_type: Box::new(self.elem_type.clone()),
605        }
606        .nullable(false)
607    }
608
609    fn introduces_nulls(&self) -> bool {
610        false
611    }
612}
613
614#[sqlfunc(sqlname = "datediff")]
615fn date_diff_date(unit_str: &str, a: Date, b: Date) -> Result<i64, EvalError> {
616    let unit = unit_str
617        .parse()
618        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
619
620    // Convert the Date into a timestamp so we can calculate age.
621    let a_ts = CheckedTimestamp::try_from(NaiveDate::from(a).and_hms_opt(0, 0, 0).unwrap())?;
622    let b_ts = CheckedTimestamp::try_from(NaiveDate::from(b).and_hms_opt(0, 0, 0).unwrap())?;
623    let diff = b_ts.diff_as(&a_ts, unit)?;
624    Ok(diff)
625}
626
627#[sqlfunc(sqlname = "datediff")]
628fn date_diff_time(unit_str: &str, a: NaiveTime, b: NaiveTime) -> Result<i64, EvalError> {
629    let unit = unit_str
630        .parse()
631        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
632
633    // Convert the Time into a timestamp so we can calculate age.
634    let a_ts =
635        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(a))?;
636    let b_ts =
637        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(b))?;
638    let diff = b_ts.diff_as(&a_ts, unit)?;
639    Ok(diff)
640}
641
642#[sqlfunc(sqlname = "datediff")]
643fn date_diff_timestamp(
644    unit: &str,
645    a: CheckedTimestamp<NaiveDateTime>,
646    b: CheckedTimestamp<NaiveDateTime>,
647) -> Result<i64, EvalError> {
648    let unit = unit
649        .parse()
650        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
651
652    let diff = b.diff_as(&a, unit)?;
653    Ok(diff)
654}
655
656#[sqlfunc(sqlname = "datediff")]
657fn date_diff_timestamp_tz(
658    unit: &str,
659    a: CheckedTimestamp<DateTime<Utc>>,
660    b: CheckedTimestamp<DateTime<Utc>>,
661) -> Result<i64, EvalError> {
662    let unit = unit
663        .parse()
664        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
665
666    let diff = b.diff_as(&a, unit)?;
667    Ok(diff)
668}
669
670#[derive(
671    Ord,
672    PartialOrd,
673    Clone,
674    Debug,
675    Eq,
676    PartialEq,
677    Serialize,
678    Deserialize,
679    Hash,
680    MzReflect
681)]
682pub struct ErrorIfNull;
683
684impl fmt::Display for ErrorIfNull {
685    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
686        f.write_str("error_if_null")
687    }
688}
689
690impl LazyVariadicFunc for ErrorIfNull {
691    fn eval<'a>(
692        &'a self,
693        datums: &[Datum<'a>],
694        temp_storage: &'a RowArena,
695        exprs: &'a [MirScalarExpr],
696    ) -> Result<Datum<'a>, EvalError> {
697        let first = exprs[0].eval(datums, temp_storage)?;
698        match first {
699            Datum::Null => {
700                let err_msg = match exprs[1].eval(datums, temp_storage)? {
701                    Datum::Null => {
702                        return Err(EvalError::Internal(
703                            "unexpected NULL in error side of error_if_null".into(),
704                        ));
705                    }
706                    o => o.unwrap_str(),
707                };
708                Err(EvalError::IfNullError(err_msg.into()))
709            }
710            _ => Ok(first),
711        }
712    }
713
714    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
715        input_types[0].scalar_type.clone().nullable(false)
716    }
717
718    fn propagates_nulls(&self) -> bool {
719        false
720    }
721
722    fn introduces_nulls(&self) -> bool {
723        false
724    }
725}
726
727#[derive(
728    Ord,
729    PartialOrd,
730    Clone,
731    Debug,
732    Eq,
733    PartialEq,
734    Serialize,
735    Deserialize,
736    Hash,
737    MzReflect
738)]
739pub struct Greatest;
740
741impl fmt::Display for Greatest {
742    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
743        f.write_str("greatest")
744    }
745}
746
747impl LazyVariadicFunc for Greatest {
748    fn eval<'a>(
749        &'a self,
750        datums: &[Datum<'a>],
751        temp_storage: &'a RowArena,
752        exprs: &'a [MirScalarExpr],
753    ) -> Result<Datum<'a>, EvalError> {
754        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
755        Ok(datums
756            .filter(|d| Ok(!d.is_null()))
757            .max()?
758            .unwrap_or(Datum::Null))
759    }
760
761    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
762        SqlColumnType::union_many(input_types)
763    }
764
765    fn propagates_nulls(&self) -> bool {
766        false
767    }
768
769    fn introduces_nulls(&self) -> bool {
770        false
771    }
772
773    fn could_error(&self) -> bool {
774        false
775    }
776
777    fn is_monotone(&self) -> bool {
778        true
779    }
780
781    fn is_associative(&self) -> bool {
782        true
783    }
784}
785
786#[sqlfunc(sqlname = "hmac")]
787fn hmac_string(to_digest: &str, key: &str, typ: &str) -> Result<Vec<u8>, EvalError> {
788    let to_digest = to_digest.as_bytes();
789    let key = key.as_bytes();
790    hmac_inner(to_digest, key, typ)
791}
792
793#[sqlfunc(sqlname = "hmac")]
794fn hmac_bytes(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
795    hmac_inner(to_digest, key, typ)
796}
797
798pub fn hmac_inner(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
799    match typ {
800        "md5" => {
801            let mut mac = Hmac::<Md5>::new_from_slice(key).expect("HMAC accepts any key size");
802            mac.update(to_digest);
803            Ok(mac.finalize().into_bytes().to_vec())
804        }
805        "sha1" => {
806            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key);
807            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
808        }
809        "sha224" => {
810            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA224, key);
811            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
812        }
813        "sha256" => {
814            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA256, key);
815            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
816        }
817        "sha384" => {
818            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA384, key);
819            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
820        }
821        "sha512" => {
822            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA512, key);
823            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
824        }
825        other => Err(EvalError::InvalidHashAlgorithm(other.into())),
826    }
827}
828
829#[sqlfunc]
830fn jsonb_build_array<'a>(datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> JsonbRef<'a> {
831    let datum = temp_storage.make_datum(|packer| {
832        packer.push_list(datums.into_iter().map(|d| match d {
833            Datum::Null => Datum::JsonNull,
834            d => d,
835        }))
836    });
837    JsonbRef::from_datum(datum)
838}
839
840#[sqlfunc]
841fn jsonb_build_object<'a>(
842    mut kvs: Variadic<(Datum<'a>, Datum<'a>)>,
843    temp_storage: &'a RowArena,
844) -> Result<JsonbRef<'a>, EvalError> {
845    kvs.0.sort_by(|kv1, kv2| kv1.0.cmp(&kv2.0));
846    kvs.0.dedup_by(|kv1, kv2| kv1.0 == kv2.0);
847    let datum = temp_storage.try_make_datum(|packer| {
848        packer.push_dict_with(|packer| {
849            for (k, v) in kvs {
850                if k.is_null() {
851                    return Err(EvalError::KeyCannotBeNull);
852                }
853                let v = match v {
854                    Datum::Null => Datum::JsonNull,
855                    d => d,
856                };
857                packer.push(k);
858                packer.push(v);
859            }
860            Ok(())
861        })
862    })?;
863    Ok(JsonbRef::from_datum(datum))
864}
865
866#[derive(
867    Ord,
868    PartialOrd,
869    Clone,
870    Debug,
871    Eq,
872    PartialEq,
873    Serialize,
874    Deserialize,
875    Hash,
876    MzReflect
877)]
878pub struct Least;
879
880impl fmt::Display for Least {
881    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
882        f.write_str("least")
883    }
884}
885
886impl LazyVariadicFunc for Least {
887    fn eval<'a>(
888        &'a self,
889        datums: &[Datum<'a>],
890        temp_storage: &'a RowArena,
891        exprs: &'a [MirScalarExpr],
892    ) -> Result<Datum<'a>, EvalError> {
893        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
894        Ok(datums
895            .filter(|d| Ok(!d.is_null()))
896            .min()?
897            .unwrap_or(Datum::Null))
898    }
899
900    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
901        SqlColumnType::union_many(input_types)
902    }
903
904    fn propagates_nulls(&self) -> bool {
905        false
906    }
907
908    fn introduces_nulls(&self) -> bool {
909        false
910    }
911
912    fn could_error(&self) -> bool {
913        false
914    }
915
916    fn is_monotone(&self) -> bool {
917        true
918    }
919
920    fn is_associative(&self) -> bool {
921        true
922    }
923}
924
925#[derive(
926    Ord,
927    PartialOrd,
928    Clone,
929    Debug,
930    Eq,
931    PartialEq,
932    Serialize,
933    Deserialize,
934    Hash,
935    MzReflect
936)]
937pub struct ListCreate {
938    pub elem_type: SqlScalarType,
939}
940
941#[sqlfunc(
942    output_type_expr = "SqlScalarType::List { element_type: Box::new(self.elem_type.clone()), custom_id: None }.nullable(false)",
943    introduces_nulls = false
944)]
945fn list_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
946    temp_storage.make_datum(|packer| packer.push_list(datums))
947}
948
949#[derive(
950    Ord,
951    PartialOrd,
952    Clone,
953    Debug,
954    Eq,
955    PartialEq,
956    Serialize,
957    Deserialize,
958    Hash,
959    MzReflect
960)]
961pub struct RecordCreate {
962    pub field_names: Vec<ColumnName>,
963}
964
965#[sqlfunc(
966    output_type_expr = "SqlScalarType::Record { fields: self.field_names.clone().into_iter().zip_eq(input_types.iter().cloned()).collect(), custom_id: None }.nullable(false)",
967    introduces_nulls = false
968)]
969fn record_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
970    temp_storage.make_datum(|packer| packer.push_list(datums.iter().copied()))
971}
972
973#[sqlfunc(
974    output_type_expr = "input_types[0].scalar_type.unwrap_list_nth_layer_type(input_types.len() - 1).clone().nullable(true)",
975    introduces_nulls = true
976)]
977// TODO(benesch): remove potentially dangerous usage of `as`.
978#[allow(clippy::as_conversions)]
979fn list_index<'a>(buf: DatumList<'a>, indices: Variadic<i64>) -> Datum<'a> {
980    let mut buf = Datum::List(buf);
981    for i in indices {
982        if buf.is_null() {
983            break;
984        }
985        if i < 1 {
986            return Datum::Null;
987        }
988
989        buf = match buf.unwrap_list().iter().nth(i as usize - 1) {
990            Some(datum) => datum,
991            None => return Datum::Null,
992        }
993    }
994    buf
995}
996
997#[sqlfunc(sqlname = "makeaclitem")]
998fn make_acl_item(
999    grantee_oid: u32,
1000    grantor_oid: u32,
1001    privileges: &str,
1002    is_grantable: bool,
1003) -> Result<AclItem, EvalError> {
1004    let grantee = Oid(grantee_oid);
1005    let grantor = Oid(grantor_oid);
1006    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1007        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1008    if is_grantable {
1009        return Err(EvalError::Unsupported {
1010            feature: "GRANT OPTION".into(),
1011            discussion_no: None,
1012        });
1013    }
1014
1015    Ok(AclItem {
1016        grantee,
1017        grantor,
1018        acl_mode,
1019    })
1020}
1021
1022#[sqlfunc(sqlname = "make_mz_aclitem")]
1023fn make_mz_acl_item(
1024    grantee_str: &str,
1025    grantor_str: &str,
1026    privileges: &str,
1027) -> Result<MzAclItem, EvalError> {
1028    let grantee: RoleId = grantee_str
1029        .parse()
1030        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1031    let grantor: RoleId = grantor_str
1032        .parse()
1033        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1034    if grantor == RoleId::Public {
1035        return Err(EvalError::InvalidRoleId(
1036            "mz_aclitem grantor cannot be PUBLIC role".into(),
1037        ));
1038    }
1039    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1040        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1041
1042    Ok(MzAclItem {
1043        grantee,
1044        grantor,
1045        acl_mode,
1046    })
1047}
1048
1049#[sqlfunc(sqlname = "makets")]
1050// TODO(benesch): remove potentially dangerous usage of `as`.
1051#[allow(clippy::as_conversions)]
1052fn make_timestamp(
1053    year: i64,
1054    month: i64,
1055    day: i64,
1056    hour: i64,
1057    minute: i64,
1058    second_float: f64,
1059) -> Result<Option<CheckedTimestamp<NaiveDateTime>>, EvalError> {
1060    let year: i32 = match year.try_into() {
1061        Ok(year) => year,
1062        Err(_) => return Ok(None),
1063    };
1064    let month: u32 = match month.try_into() {
1065        Ok(month) => month,
1066        Err(_) => return Ok(None),
1067    };
1068    let day: u32 = match day.try_into() {
1069        Ok(day) => day,
1070        Err(_) => return Ok(None),
1071    };
1072    let hour: u32 = match hour.try_into() {
1073        Ok(hour) => hour,
1074        Err(_) => return Ok(None),
1075    };
1076    let minute: u32 = match minute.try_into() {
1077        Ok(minute) => minute,
1078        Err(_) => return Ok(None),
1079    };
1080    let second = second_float as u32;
1081    let micros = ((second_float - second as f64) * 1_000_000.0) as u32;
1082    let date = match NaiveDate::from_ymd_opt(year, month, day) {
1083        Some(date) => date,
1084        None => return Ok(None),
1085    };
1086    let timestamp = match date.and_hms_micro_opt(hour, minute, second, micros) {
1087        Some(timestamp) => timestamp,
1088        None => return Ok(None),
1089    };
1090    Ok(Some(timestamp.try_into()?))
1091}
1092
/// Variadic function state for building a map datum from key/value pairs
/// (see `map_build`).
#[derive(
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct MapBuild {
    /// Scalar type of the map's values; used as the `value_type` of the
    /// `SqlScalarType::Map` output type declared on `map_build`.
    pub value_type: SqlScalarType,
}
1108
1109#[sqlfunc(
1110    output_type_expr = "SqlScalarType::Map { value_type: Box::new(self.value_type.clone()), custom_id: None }.nullable(false)",
1111    introduces_nulls = false
1112)]
1113fn map_build<'a>(
1114    &self,
1115    datums: Variadic<(Option<&str>, Datum<'a>)>,
1116    temp_storage: &'a RowArena,
1117) -> Datum<'a> {
1118    // Collect into a `BTreeMap` to provide the same semantics as it.
1119    let map: std::collections::BTreeMap<&str, _> = datums
1120        .into_iter()
1121        .filter_map(|(k, v)| k.map(|k| (k, v)))
1122        .collect();
1123
1124    temp_storage.make_datum(|packer| packer.push_dict(map))
1125}
1126
/// The variadic boolean `OR` function.
///
/// Its evaluation logic lives in the `LazyVariadicFunc` impl for this type and
/// it is displayed as `OR`.
#[derive(
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct Or;
1140
1141impl fmt::Display for Or {
1142    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1143        f.write_str("OR")
1144    }
1145}
1146
1147impl LazyVariadicFunc for Or {
1148    fn eval<'a>(
1149        &'a self,
1150        datums: &[Datum<'a>],
1151        temp_storage: &'a RowArena,
1152        exprs: &'a [MirScalarExpr],
1153    ) -> Result<Datum<'a>, EvalError> {
1154        // If any is true, then return true. Else, if any is null, then return null. Else, return false.
1155        let mut null = false;
1156        let mut err = None;
1157        for expr in exprs {
1158            match expr.eval(datums, temp_storage) {
1159                Ok(Datum::False) => {}
1160                Ok(Datum::True) => return Ok(Datum::True), // short-circuit
1161                // No return in these two cases, because we might still see a true
1162                Ok(Datum::Null) => null = true,
1163                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
1164                _ => unreachable!(),
1165            }
1166        }
1167        match (err, null) {
1168            (Some(err), _) => Err(err),
1169            (None, true) => Ok(Datum::Null),
1170            (None, false) => Ok(Datum::False),
1171        }
1172    }
1173
1174    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1175        let in_nullable = input_types.iter().any(|t| t.nullable);
1176        SqlScalarType::Bool.nullable(in_nullable)
1177    }
1178
1179    fn propagates_nulls(&self) -> bool {
1180        false
1181    }
1182
1183    fn introduces_nulls(&self) -> bool {
1184        false
1185    }
1186
1187    fn could_error(&self) -> bool {
1188        false
1189    }
1190
1191    fn is_monotone(&self) -> bool {
1192        true
1193    }
1194
1195    fn is_associative(&self) -> bool {
1196        true
1197    }
1198
1199    fn is_infix_op(&self) -> bool {
1200        true
1201    }
1202}
1203
1204#[sqlfunc(sqlname = "lpad")]
1205fn pad_leading(string: &str, raw_len: i32, pad: OptionalArg<&str>) -> Result<String, EvalError> {
1206    let len = match usize::try_from(raw_len) {
1207        Ok(len) => len,
1208        Err(_) => {
1209            return Err(EvalError::InvalidParameterValue(
1210                "length must be nonnegative".into(),
1211            ));
1212        }
1213    };
1214    if len > MAX_STRING_FUNC_RESULT_BYTES {
1215        return Err(EvalError::LengthTooLarge);
1216    }
1217
1218    let pad_string = pad.unwrap_or(" ");
1219
1220    let (end_char, end_char_byte_offset) = string
1221        .chars()
1222        .take(len)
1223        .fold((0, 0), |acc, char| (acc.0 + 1, acc.1 + char.len_utf8()));
1224
1225    let mut buf = String::with_capacity(len);
1226    if len == end_char {
1227        buf.push_str(&string[0..end_char_byte_offset]);
1228    } else {
1229        buf.extend(pad_string.chars().cycle().take(len - end_char));
1230        buf.push_str(string);
1231    }
1232
1233    Ok(buf)
1234}
1235
1236#[sqlfunc(
1237    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true)",
1238    introduces_nulls = true
1239)]
1240fn regexp_match<'a>(
1241    haystack: &'a str,
1242    needle: &str,
1243    flags: OptionalArg<&str>,
1244    temp_storage: &'a RowArena,
1245) -> Result<Datum<'a>, EvalError> {
1246    let flags = flags.unwrap_or("");
1247    let needle = build_regex(needle, flags)?;
1248    regexp_match_static(Datum::String(haystack), temp_storage, &needle)
1249}
1250
1251#[sqlfunc(
1252    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1253    introduces_nulls = false
1254)]
1255fn regexp_split_to_array<'a>(
1256    text: &str,
1257    regexp_str: &str,
1258    flags: OptionalArg<&str>,
1259    temp_storage: &'a RowArena,
1260) -> Result<Datum<'a>, EvalError> {
1261    let flags = flags.unwrap_or("");
1262    let regexp = build_regex(regexp_str, flags)?;
1263    regexp_split_to_array_re(text, &regexp, temp_storage)
1264}
1265
1266#[sqlfunc]
1267fn regexp_replace<'a>(
1268    source: &'a str,
1269    pattern: &str,
1270    replacement: &str,
1271    flags_opt: OptionalArg<&str>,
1272) -> Result<Cow<'a, str>, EvalError> {
1273    let flags = flags_opt.0.unwrap_or("");
1274    let (limit, flags) = regexp_replace_parse_flags(flags);
1275    let regexp = build_regex(pattern, &flags)?;
1276    Ok(regexp.replacen(source, limit, replacement))
1277}
1278
1279#[sqlfunc]
1280fn replace(text: &str, from: &str, to: &str) -> Result<String, EvalError> {
1281    // As a compromise to avoid always nearly duplicating the work of replace by doing size estimation,
1282    // we first check if it's possible for the fully replaced string to exceed the limit by assuming that
1283    // every possible substring is replaced.
1284    //
1285    // If that estimate exceeds the limit, we then do a more precise (and expensive) estimate by counting
1286    // the actual number of replacements that would occur, and using that to calculate the final size.
1287    let possible_size = text.len() * to.len();
1288    if possible_size > MAX_STRING_FUNC_RESULT_BYTES {
1289        let replacement_count = text.matches(from).count();
1290        let estimated_size = text.len() + replacement_count * (to.len().saturating_sub(from.len()));
1291        if estimated_size > MAX_STRING_FUNC_RESULT_BYTES {
1292            return Err(EvalError::LengthTooLarge);
1293        }
1294    }
1295
1296    Ok(text.replace(from, to))
1297}
1298
1299#[sqlfunc(
1300    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1301    introduces_nulls = false,
1302    propagates_nulls = false
1303)]
1304fn string_to_array<'a>(
1305    string: &'a str,
1306    delimiter: Option<&'a str>,
1307    null_string: OptionalArg<Option<&'a str>>,
1308    temp_storage: &'a RowArena,
1309) -> Result<Datum<'a>, EvalError> {
1310    if string.is_empty() {
1311        let mut row = Row::default();
1312        let mut packer = row.packer();
1313        packer.try_push_array(&[], std::iter::empty::<Datum>())?;
1314
1315        return Ok(temp_storage.push_unary_row(row));
1316    }
1317
1318    let Some(delimiter) = delimiter else {
1319        let split_all_chars_delimiter = "";
1320        return string_to_array_impl(
1321            string,
1322            split_all_chars_delimiter,
1323            null_string.flatten(),
1324            temp_storage,
1325        );
1326    };
1327
1328    if delimiter.is_empty() {
1329        let mut row = Row::default();
1330        let mut packer = row.packer();
1331        let dims = &[ArrayDimension {
1332            lower_bound: 1,
1333            length: 1,
1334        }];
1335        match null_string.flatten() {
1336            Some(null_string) if null_string == string => {
1337                packer.try_push_array(dims, std::iter::once(Datum::Null))?;
1338            }
1339            _ => {
1340                packer.try_push_array(dims, vec![string].into_iter().map(Datum::String))?;
1341            }
1342        }
1343        Ok(temp_storage.push_unary_row(row))
1344    } else {
1345        string_to_array_impl(string, delimiter, null_string.flatten(), temp_storage)
1346    }
1347}
1348
1349fn string_to_array_impl<'a>(
1350    string: &str,
1351    delimiter: &str,
1352    null_string: Option<&'a str>,
1353    temp_storage: &'a RowArena,
1354) -> Result<Datum<'a>, EvalError> {
1355    let mut row = Row::default();
1356    let mut packer = row.packer();
1357
1358    let result = string.split(delimiter);
1359    let found: Vec<&str> = if delimiter.is_empty() {
1360        result.filter(|s| !s.is_empty()).collect()
1361    } else {
1362        result.collect()
1363    };
1364    let array_dimensions = [ArrayDimension {
1365        lower_bound: 1,
1366        length: found.len(),
1367    }];
1368
1369    if let Some(null_string) = null_string {
1370        let found_datums = found.into_iter().map(|chunk| {
1371            if chunk.eq(null_string) {
1372                Datum::Null
1373            } else {
1374                Datum::String(chunk)
1375            }
1376        });
1377
1378        packer.try_push_array(&array_dimensions, found_datums)?;
1379    } else {
1380        packer.try_push_array(&array_dimensions, found.into_iter().map(Datum::String))?;
1381    }
1382
1383    Ok(temp_storage.push_unary_row(row))
1384}
1385
1386#[sqlfunc]
1387fn substr<'a>(s: &'a str, start: i32, length: OptionalArg<i32>) -> Result<&'a str, EvalError> {
1388    let raw_start_idx = i64::from(start) - 1;
1389    let start_idx = match usize::try_from(cmp::max(raw_start_idx, 0)) {
1390        Ok(i) => i,
1391        Err(_) => {
1392            return Err(EvalError::InvalidParameterValue(
1393                format!(
1394                    "substring starting index ({}) exceeds min/max position",
1395                    raw_start_idx
1396                )
1397                .into(),
1398            ));
1399        }
1400    };
1401
1402    let mut char_indices = s.char_indices();
1403    let get_str_index = |(index, _char)| index;
1404
1405    let str_len = s.len();
1406    let start_char_idx = char_indices.nth(start_idx).map_or(str_len, get_str_index);
1407
1408    if let OptionalArg(Some(len)) = length {
1409        let end_idx = match i64::from(len) {
1410            e if e < 0 => {
1411                return Err(EvalError::InvalidParameterValue(
1412                    "negative substring length not allowed".into(),
1413                ));
1414            }
1415            e if e == 0 || e + raw_start_idx < 1 => return Ok(""),
1416            e => {
1417                let e = cmp::min(raw_start_idx + e - 1, e - 1);
1418                match usize::try_from(e) {
1419                    Ok(i) => i,
1420                    Err(_) => {
1421                        return Err(EvalError::InvalidParameterValue(
1422                            format!("substring length ({}) exceeds max position", e).into(),
1423                        ));
1424                    }
1425                }
1426            }
1427        };
1428
1429        let end_char_idx = char_indices.nth(end_idx).map_or(str_len, get_str_index);
1430
1431        Ok(&s[start_char_idx..end_char_idx])
1432    } else {
1433        Ok(&s[start_char_idx..])
1434    }
1435}
1436
1437#[sqlfunc(sqlname = "split_string")]
1438fn split_part<'a>(string: &'a str, delimiter: &str, field: i32) -> Result<&'a str, EvalError> {
1439    let index = match usize::try_from(i64::from(field) - 1) {
1440        Ok(index) => index,
1441        Err(_) => {
1442            return Err(EvalError::InvalidParameterValue(
1443                "field position must be greater than zero".into(),
1444            ));
1445        }
1446    };
1447
1448    // If the provided delimiter is the empty string,
1449    // PostgreSQL does not break the string into individual
1450    // characters. Instead, it generates the following parts: [string].
1451    if delimiter.is_empty() {
1452        if index == 0 {
1453            return Ok(string);
1454        } else {
1455            return Ok("");
1456        }
1457    }
1458
1459    // If provided index is greater than the number of split parts,
1460    // return an empty string.
1461    Ok(string.split(delimiter).nth(index).unwrap_or(""))
1462}
1463
1464#[sqlfunc(is_associative = true)]
1465fn concat(strs: Variadic<Option<&str>>) -> Result<String, EvalError> {
1466    let mut total_size = 0;
1467    for s in &strs {
1468        if let Some(s) = s {
1469            total_size += s.len();
1470            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1471                return Err(EvalError::LengthTooLarge);
1472            }
1473        }
1474    }
1475    let mut buf = String::with_capacity(total_size);
1476    for s in strs {
1477        if let Some(s) = s {
1478            buf.push_str(s);
1479        }
1480    }
1481    Ok(buf)
1482}
1483
1484#[sqlfunc]
1485fn concat_ws(ws: &str, rest: Variadic<Option<&str>>) -> Result<String, EvalError> {
1486    let mut total_size = 0;
1487    for s in &rest {
1488        if let Some(s) = s {
1489            total_size += s.len();
1490            total_size += ws.len();
1491            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1492                return Err(EvalError::LengthTooLarge);
1493            }
1494        }
1495    }
1496
1497    let buf = Itertools::join(&mut rest.into_iter().filter_map(|s| s), ws);
1498
1499    Ok(buf)
1500}
1501
1502#[sqlfunc]
1503fn translate(string: &str, from_str: &str, to_str: &str) -> String {
1504    let from = from_str.chars().collect::<Vec<_>>();
1505    let to = to_str.chars().collect::<Vec<_>>();
1506
1507    string
1508        .chars()
1509        .filter_map(|c| match from.iter().position(|f| f == &c) {
1510            Some(idx) => to.get(idx).copied(),
1511            None => Some(c),
1512        })
1513        .collect()
1514}
1515
1516#[sqlfunc(
1517    output_type_expr = "input_types[0].scalar_type.clone().nullable(false)",
1518    introduces_nulls = false
1519)]
1520// TODO(benesch): remove potentially dangerous usage of `as`.
1521#[allow(clippy::as_conversions)]
1522fn list_slice_linear<'a>(
1523    list: DatumList<'a>,
1524    first: (i64, i64),
1525    remainder: Variadic<(i64, i64)>,
1526    temp_storage: &'a RowArena,
1527) -> Datum<'a> {
1528    let mut start_idx = 0;
1529    let mut total_length = usize::MAX;
1530
1531    for (start, end) in std::iter::once(first).chain(remainder) {
1532        let start = std::cmp::max(start, 1);
1533
1534        // Result should be empty list.
1535        if start > end {
1536            start_idx = 0;
1537            total_length = 0;
1538            break;
1539        }
1540
1541        let start_inner = start as usize - 1;
1542        // Start index only moves to geq positions.
1543        start_idx += start_inner;
1544
1545        // Length index only moves to leq positions
1546        let length_inner = (end - start) as usize + 1;
1547        total_length = std::cmp::min(length_inner, total_length - start_inner);
1548    }
1549
1550    let iter = list.iter().skip(start_idx).take(total_length);
1551
1552    temp_storage.make_datum(|row| {
1553        row.push_list_with(|row| {
1554            // if iter is empty, will get the appropriate empty list.
1555            for d in iter {
1556                row.push(d);
1557            }
1558        });
1559    })
1560}
1561
1562#[sqlfunc(sqlname = "timestamp_bin")]
1563fn date_bin_timestamp(
1564    stride: Interval,
1565    source: CheckedTimestamp<NaiveDateTime>,
1566    origin: CheckedTimestamp<NaiveDateTime>,
1567) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
1568    date_bin(stride, source, origin)
1569}
1570
1571#[sqlfunc(sqlname = "timestamptz_bin")]
1572fn date_bin_timestamp_tz(
1573    stride: Interval,
1574    source: CheckedTimestamp<DateTime<Utc>>,
1575    origin: CheckedTimestamp<DateTime<Utc>>,
1576) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
1577    date_bin(stride, source, origin)
1578}
1579
1580#[sqlfunc(sqlname = "timezonet")]
1581fn timezone_time_variadic(
1582    tz_str: &str,
1583    time: NaiveTime,
1584    wall_time: CheckedTimestamp<DateTime<Utc>>,
1585) -> Result<NaiveTime, EvalError> {
1586    parse_timezone(tz_str, TimezoneSpec::Posix)
1587        .map(|tz| timezone_time(tz, time, &wall_time.naive_utc()))
1588}
/// A variadic function that receives its arguments as unevaluated
/// expressions, so an implementation decides which of them to evaluate and in
/// what order (see the `Or` impl, which short-circuits).
pub(crate) trait LazyVariadicFunc: fmt::Display {
    /// Evaluates the function.
    ///
    /// `exprs` are the argument expressions; implementations evaluate them
    /// against the input `datums`, using `temp_storage` for any allocations.
    fn eval<'a>(
        &'a self,
        datums: &[Datum<'a>],
        temp_storage: &'a RowArena,
        exprs: &'a [MirScalarExpr],
    ) -> Result<Datum<'a>, EvalError>;

    /// The output SqlColumnType of this function.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input.
    fn propagates_nulls(&self) -> bool;

    /// Whether this function will produce NULL on non-NULL input.
    fn introduces_nulls(&self) -> bool;

    /// Whether this function might error on non-error input.
    ///
    /// Defaults to `true`, the conservative answer.
    fn could_error(&self) -> bool {
        true
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1626
/// A variadic function whose arguments are all evaluated and converted into a
/// typed `Input` before `call` runs.
///
/// Every implementor also gets `LazyVariadicFunc` through the blanket impl in
/// this file, which performs the evaluation and conversion.
pub(crate) trait EagerVariadicFunc: fmt::Display {
    /// Typed representation of the already-evaluated arguments.
    type Input<'a>: InputDatumType<'a, EvalError>;
    /// Typed result, converted to a `Datum` via `into_result`.
    type Output<'a>: OutputDatumType<'a, EvalError>;

    /// Applies the function to its converted input, using `temp_storage` for
    /// any allocations the result needs.
    fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;

    /// The output SqlColumnType of this function.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input.
    ///
    /// Defaults to `true` exactly when the input type is not nullable.
    fn propagates_nulls(&self) -> bool {
        !Self::Input::nullable()
    }

    /// Whether this function will produce NULL on non-NULL input.
    ///
    /// Defaults to the nullability of the output type.
    fn introduces_nulls(&self) -> bool {
        Self::Output::nullable()
    }

    /// Whether this function might error on non-error input.
    ///
    /// Defaults to the fallibility of the output type.
    fn could_error(&self) -> bool {
        Self::Output::fallible()
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1659
1660/// Blanket `LazyVariadicFunc` impl for each eager type, bridging
1661/// expression evaluation and null propagation via `InputDatumType::try_from_iter`.
1662impl<T: EagerVariadicFunc> LazyVariadicFunc for T {
1663    fn eval<'a>(
1664        &'a self,
1665        datums: &[Datum<'a>],
1666        temp_storage: &'a RowArena,
1667        exprs: &'a [MirScalarExpr],
1668    ) -> Result<Datum<'a>, EvalError> {
1669        let mut datums = exprs.iter().map(|e| e.eval(datums, temp_storage));
1670        match T::Input::try_from_iter(&mut datums) {
1671            Ok(input) => self.call(input, temp_storage).into_result(temp_storage),
1672            Err(Ok(None)) => Err(EvalError::Internal("missing parameter".into())),
1673            Err(Ok(Some(datum))) if datum.is_null() => Ok(datum),
1674            Err(Ok(Some(_datum))) => {
1675                // datum is _not_ NULL
1676                Err(EvalError::Internal("invalid input type".into()))
1677            }
1678            Err(Err(res)) => Err(res),
1679        }
1680    }
1681
1682    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1683        self.output_type(input_types)
1684    }
1685
1686    fn propagates_nulls(&self) -> bool {
1687        self.propagates_nulls()
1688    }
1689
1690    fn introduces_nulls(&self) -> bool {
1691        self.introduces_nulls()
1692    }
1693
1694    fn could_error(&self) -> bool {
1695        self.could_error()
1696    }
1697
1698    fn is_monotone(&self) -> bool {
1699        self.is_monotone()
1700    }
1701
1702    fn is_associative(&self) -> bool {
1703        self.is_associative()
1704    }
1705
1706    fn is_infix_op(&self) -> bool {
1707        self.is_infix_op()
1708    }
1709}
1710
// Registry of all variadic functions, one `Variant(Type)` entry per function.
//
// NOTE(review): the macro definition is not visible here. Judging by the
// `VariadicFunc::And(_)` / `VariadicFunc::Or(_)` patterns and the `Or.into()`
// conversions used below, this presumably generates the `VariadicFunc` enum
// and the conversions from each function type — confirm against the macro.
// Entry order may be significant for generated code (e.g. derived ordering or
// serialization); do not reorder without checking.
derive_variadic! {
    Coalesce(Coalesce),
    Greatest(Greatest),
    Least(Least),
    Concat(Concat),
    ConcatWs(ConcatWs),
    MakeTimestamp(MakeTimestamp),
    PadLeading(PadLeading),
    Substr(Substr),
    Replace(Replace),
    JsonbBuildArray(JsonbBuildArray),
    JsonbBuildObject(JsonbBuildObject),
    MapBuild(MapBuild),
    ArrayCreate(ArrayCreate),
    ArrayToString(ArrayToString),
    ArrayIndex(ArrayIndex),
    ListCreate(ListCreate),
    RecordCreate(RecordCreate),
    ListIndex(ListIndex),
    ListSliceLinear(ListSliceLinear),
    SplitPart(SplitPart),
    RegexpMatch(RegexpMatch),
    HmacString(HmacString),
    HmacBytes(HmacBytes),
    ErrorIfNull(ErrorIfNull),
    DateBinTimestamp(DateBinTimestamp),
    DateBinTimestampTz(DateBinTimestampTz),
    DateDiffTimestamp(DateDiffTimestamp),
    DateDiffTimestampTz(DateDiffTimestampTz),
    DateDiffDate(DateDiffDate),
    DateDiffTime(DateDiffTime),
    And(And),
    Or(Or),
    RangeCreate(RangeCreate),
    MakeAclItem(MakeAclItem),
    MakeMzAclItem(MakeMzAclItem),
    Translate(Translate),
    ArrayPosition(ArrayPosition),
    ArrayFill(ArrayFill),
    StringToArray(StringToArray),
    TimezoneTimeVariadic(TimezoneTimeVariadic),
    RegexpSplitToArray(RegexpSplitToArray),
    RegexpReplace(RegexpReplace),
    CaseLiteral(CaseLiteral),
}
1756
1757impl VariadicFunc {
1758    pub fn switch_and_or(&self) -> Self {
1759        match self {
1760            VariadicFunc::And(_) => Or.into(),
1761            VariadicFunc::Or(_) => And.into(),
1762            _ => unreachable!(),
1763        }
1764    }
1765
1766    /// Gives the unit (u) of OR or AND, such that `u AND/OR x == x`.
1767    /// Note that a 0-arg AND/OR evaluates to unit_of_and_or.
1768    pub fn unit_of_and_or(&self) -> MirScalarExpr {
1769        match self {
1770            VariadicFunc::And(_) => MirScalarExpr::literal_true(),
1771            VariadicFunc::Or(_) => MirScalarExpr::literal_false(),
1772            _ => unreachable!(),
1773        }
1774    }
1775
1776    /// Gives the zero (z) of OR or AND, such that `z AND/OR x == z`.
1777    pub fn zero_of_and_or(&self) -> MirScalarExpr {
1778        match self {
1779            VariadicFunc::And(_) => MirScalarExpr::literal_false(),
1780            VariadicFunc::Or(_) => MirScalarExpr::literal_true(),
1781            _ => unreachable!(),
1782        }
1783    }
1784}