Skip to main content

mz_expr/scalar/func/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Variadic functions.
15
16use std::borrow::Cow;
17use std::cmp;
18use std::fmt;
19
20use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
21use fallible_iterator::FallibleIterator;
22use hmac::{Hmac, Mac};
23use itertools::Itertools;
24use md5::Md5;
25use mz_expr_derive::sqlfunc;
26use mz_lowertest::MzReflect;
27use mz_ore::cast::{CastFrom, ReinterpretCast};
28use mz_pgtz::timezone::TimezoneSpec;
29use mz_repr::ReprColumnType;
30use mz_repr::adt::array::{Array, ArrayDimension, ArrayDimensions, InvalidArrayError};
31use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem};
32use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags};
33use mz_repr::adt::system::Oid;
34use mz_repr::adt::timestamp::CheckedTimestamp;
35use mz_repr::role_id::RoleId;
36use mz_repr::{
37    ColumnName, Datum, DatumList, InputDatumType, OptionalArg, OutputDatumType, Row, RowArena,
38    SqlColumnType, SqlScalarType, Variadic,
39};
40use serde::{Deserialize, Serialize};
41use sha1::Sha1;
42use sha2::{Sha224, Sha256, Sha384, Sha512};
43
44use crate::func::{
45    MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin, parse_timezone,
46    regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re, stringify_datum,
47    timezone_time,
48};
49use crate::{EvalError, MirScalarExpr};
50use mz_repr::adt::date::Date;
51use mz_repr::adt::interval::Interval;
52use mz_repr::adt::jsonb::JsonbRef;
53
54#[derive(
55    Ord,
56    PartialOrd,
57    Clone,
58    Debug,
59    Eq,
60    PartialEq,
61    Serialize,
62    Deserialize,
63    Hash,
64    MzReflect
65)]
66pub struct And;
67
68impl fmt::Display for And {
69    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70        f.write_str("AND")
71    }
72}
73
74impl LazyVariadicFunc for And {
75    fn eval<'a>(
76        &'a self,
77        datums: &[Datum<'a>],
78        temp_storage: &'a RowArena,
79        exprs: &'a [MirScalarExpr],
80    ) -> Result<Datum<'a>, EvalError> {
81        // If any is false, then return false. Else, if any is null, then return null. Else, return true.
82        let mut null = false;
83        let mut err = None;
84        for expr in exprs {
85            match expr.eval(datums, temp_storage) {
86                Ok(Datum::False) => return Ok(Datum::False), // short-circuit
87                Ok(Datum::True) => {}
88                // No return in these two cases, because we might still see a false
89                Ok(Datum::Null) => null = true,
90                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
91                _ => unreachable!(),
92            }
93        }
94        match (err, null) {
95            (Some(err), _) => Err(err),
96            (None, true) => Ok(Datum::Null),
97            (None, false) => Ok(Datum::True),
98        }
99    }
100
101    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
102        let in_nullable = input_types.iter().any(|t| t.nullable);
103        SqlScalarType::Bool.nullable(in_nullable)
104    }
105
106    fn propagates_nulls(&self) -> bool {
107        false
108    }
109
110    fn introduces_nulls(&self) -> bool {
111        false
112    }
113
114    fn could_error(&self) -> bool {
115        false
116    }
117
118    fn is_monotone(&self) -> bool {
119        true
120    }
121
122    fn is_associative(&self) -> bool {
123        true
124    }
125
126    fn is_infix_op(&self) -> bool {
127        true
128    }
129}
130
131#[derive(
132    Ord,
133    PartialOrd,
134    Clone,
135    Debug,
136    Eq,
137    PartialEq,
138    Serialize,
139    Deserialize,
140    Hash,
141    MzReflect
142)]
143pub struct ArrayCreate {
144    pub elem_type: SqlScalarType,
145}
146
147/// Constructs a new multidimensional array out of an arbitrary number of
148/// lower-dimensional arrays.
149///
150/// For example, if given three 1D arrays of length 2, this function will
151/// construct a 2D array with dimensions 3x2.
152///
153/// The input datums in `datums` must all be arrays of the same dimensions.
154/// (The arrays must also be of the same element type, but that is checked by
155/// the SQL type system, rather than checked here at runtime.)
156///
157/// If all input arrays are zero-dimensional arrays, then the output is a zero-
158/// dimensional array. Otherwise, the lower bound of the additional dimension is
159/// one and the length of the new dimension is equal to `datums.len()`.
160///
161/// Null elements are allowed and considered to be zero-dimensional arrays.
162#[sqlfunc(
163    ArrayCreate,
164    output_type_expr = "match &self.elem_type { SqlScalarType::Array(_) => self.elem_type.clone().nullable(false), _ => SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false) }",
165    introduces_nulls = false
166)]
167fn array_create<'a>(
168    &self,
169    datums: Variadic<Datum<'a>>,
170    temp_storage: &'a RowArena,
171) -> Result<Datum<'a>, EvalError> {
172    match &self.elem_type {
173        SqlScalarType::Array(_) => array_create_multidim(&datums, temp_storage),
174        _ => array_create_scalar(&datums, temp_storage),
175    }
176}
/// Stacks the input arrays into a new array with one additional outer
/// dimension (the array-typed case of `array_create`).
///
/// NULL inputs are treated as zero-dimensional arrays. If there are no inputs,
/// or every input is zero-dimensional, the result is a zero-dimensional array.
///
/// # Errors
///
/// Returns `InvalidArrayError::WrongCardinality` if the inputs disagree on
/// their number of dimensions, or on any individual dimension.
///
/// # Panics
///
/// Panics if an input is neither NULL nor an array.
fn array_create_multidim<'a>(
    datums: &[Datum<'a>],
    temp_storage: &'a RowArena,
) -> Result<Datum<'a>, EvalError> {
    // Dimensions of the previous input; each input is checked against it.
    let mut dim: Option<ArrayDimensions> = None;
    for datum in datums {
        let actual_dims = match datum {
            Datum::Null => ArrayDimensions::default(),
            Datum::Array(arr) => arr.dims(),
            d => panic!("unexpected datum {d}"),
        };
        if let Some(expected) = &dim {
            if actual_dims.ndims() != expected.ndims() {
                let actual = actual_dims.ndims().into();
                let expected = expected.ndims().into();
                // All input arrays must have the same dimensionality.
                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
            }
            if let Some((e, a)) = expected
                .into_iter()
                .zip_eq(actual_dims.into_iter())
                .find(|(e, a)| e != a)
            {
                let actual = a.length;
                let expected = e.length;
                // All input arrays must also agree on every individual dimension.
                // NOTE(review): `!=` compares whole `ArrayDimension` values
                // (presumably lower bound and length), but only the lengths are
                // reported. A lower-bound-only mismatch therefore produces an
                // error whose `actual` and `expected` look equal.
                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
            }
        }
        dim = Some(actual_dims);
    }
    // Per PostgreSQL, if all input arrays are zero dimensional, so is the output.
    // `dim` is `None` when there were no inputs, which is handled the same way.
    if dim.as_ref().map_or(true, ArrayDimensions::is_empty) {
        return Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&[], &[]))?);
    }

    // The new outer dimension has one entry per input, followed by the inputs'
    // shared dimensions. Reaching this point means every input had the same,
    // non-zero number of dimensions. NULLs count as zero-dimensional, so none
    // of the inputs was NULL and the `unwrap_array` calls below should not panic.
    let mut dims = vec![ArrayDimension {
        lower_bound: 1,
        length: datums.len(),
    }];
    if let Some(d) = datums.first() {
        dims.extend(d.unwrap_array().dims());
    };
    // Elements in input order, with each input's elements kept contiguous.
    let elements = datums
        .iter()
        .flat_map(|d| d.unwrap_array().elements().iter());
    let datum =
        temp_storage.try_make_datum(move |packer| packer.try_push_array(&dims, elements))?;
    Ok(datum)
}
227
228#[derive(
229    Ord,
230    PartialOrd,
231    Clone,
232    Debug,
233    Eq,
234    PartialEq,
235    Serialize,
236    Deserialize,
237    Hash,
238    MzReflect
239)]
240pub struct ArrayFill {
241    pub elem_type: SqlScalarType,
242}
243
244#[sqlfunc(
245    ArrayFill,
246    output_type_expr = "SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false)",
247    introduces_nulls = false
248)]
249fn array_fill<'a>(
250    &self,
251    fill: Datum<'a>,
252    dims: Option<Array<'a>>,
253    lower_bounds: OptionalArg<Option<Array<'a>>>,
254    temp_storage: &'a RowArena,
255) -> Result<Datum<'a>, EvalError> {
256    const MAX_SIZE: usize = 1 << 28 - 1;
257    const NULL_ARR_ERR: &str = "dimension array or low bound array";
258    const NULL_ELEM_ERR: &str = "dimension values";
259
260    if matches!(fill, Datum::Array(_)) {
261        return Err(EvalError::Unsupported {
262            feature: "array_fill with arrays".into(),
263            discussion_no: None,
264        });
265    }
266
267    let Some(arr) = dims else {
268        return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
269    };
270
271    let dimensions = arr
272        .elements()
273        .iter()
274        .map(|d| match d {
275            Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
276            d => Ok(usize::cast_from(u32::reinterpret_cast(d.unwrap_int32()))),
277        })
278        .collect::<Result<Vec<_>, _>>()?;
279
280    let lower_bounds = match *lower_bounds {
281        Some(d) => {
282            let Some(arr) = d else {
283                return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
284            };
285
286            arr.elements()
287                .iter()
288                .map(|l| match l {
289                    Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
290                    l => Ok(isize::cast_from(l.unwrap_int32())),
291                })
292                .collect::<Result<Vec<_>, _>>()?
293        }
294        None => {
295            vec![1isize; dimensions.len()]
296        }
297    };
298
299    if lower_bounds.len() != dimensions.len() {
300        return Err(EvalError::ArrayFillWrongArraySubscripts);
301    }
302
303    let fill_count: usize = dimensions
304        .iter()
305        .cloned()
306        .map(Some)
307        .reduce(|a, b| match (a, b) {
308            (Some(a), Some(b)) => a.checked_mul(b),
309            _ => None,
310        })
311        .flatten()
312        .ok_or(EvalError::MaxArraySizeExceeded(MAX_SIZE))?;
313
314    if matches!(
315        mz_repr::datum_size(&fill).checked_mul(fill_count),
316        None | Some(MAX_SIZE..)
317    ) {
318        return Err(EvalError::MaxArraySizeExceeded(MAX_SIZE));
319    }
320
321    let array_dimensions = if fill_count == 0 {
322        vec![ArrayDimension {
323            lower_bound: 1,
324            length: 0,
325        }]
326    } else {
327        dimensions
328            .into_iter()
329            .zip_eq(lower_bounds)
330            .map(|(length, lower_bound)| ArrayDimension {
331                lower_bound,
332                length,
333            })
334            .collect()
335    };
336
337    Ok(temp_storage.try_make_datum(|packer| {
338        packer.try_push_array(&array_dimensions, vec![fill; fill_count])
339    })?)
340}
341
342#[derive(
343    Ord,
344    PartialOrd,
345    Clone,
346    Debug,
347    Eq,
348    PartialEq,
349    Serialize,
350    Deserialize,
351    Hash,
352    MzReflect
353)]
354pub struct ArrayIndex {
355    pub offset: i64,
356}
357#[sqlfunc(
358    ArrayIndex,
359    sqlname = "array_index",
360    output_type_expr = "input_types[0].scalar_type.unwrap_array_element_type().clone().nullable(true)",
361    introduces_nulls = true
362)]
363fn array_index<'a>(&self, array: Array<'a>, indices: Variadic<i64>) -> Option<Datum<'a>> {
364    mz_ore::soft_assert_no_log!(
365        self.offset == 0 || self.offset == 1,
366        "offset must be either 0 or 1"
367    );
368
369    let dims = array.dims();
370    if dims.len() != indices.len() {
371        // You missed the datums "layer"
372        return None;
373    }
374
375    let mut final_idx = 0;
376
377    for (d, idx) in dims.into_iter().zip_eq(indices.iter()) {
378        // Lower bound is written in terms of 1-based indexing, which offset accounts for.
379        let idx = isize::cast_from(*idx + self.offset);
380
381        let (lower, upper) = d.dimension_bounds();
382
383        // This index missed all of the data at this layer. The dimension bounds are inclusive,
384        // while range checks are exclusive, so adjust.
385        if !(lower..upper + 1).contains(&idx) {
386            return None;
387        }
388
389        // We discover how many indices our last index represents physically.
390        final_idx *= d.length;
391
392        // Because both index and lower bound are handled in 1-based indexing, taking their
393        // difference moves us back into 0-based indexing. Similarly, if the lower bound is
394        // negative, subtracting a negative value >= to itself ensures its non-negativity.
395        final_idx += usize::try_from(idx - d.lower_bound)
396            .expect("previous bounds check ensures physical index is at least 0");
397    }
398
399    array.elements().iter().nth(final_idx)
400}
401
402#[sqlfunc]
403fn array_position<'a>(
404    array: Array<'a>,
405    search: Datum<'a>,
406    initial_pos: OptionalArg<Option<i32>>,
407) -> Result<Option<i32>, EvalError> {
408    if array.dims().len() > 1 {
409        return Err(EvalError::MultiDimensionalArraySearch);
410    }
411
412    if search == Datum::Null {
413        return Ok(None);
414    }
415
416    let skip = match initial_pos.0 {
417        None => 0,
418        Some(None) => return Err(EvalError::MustNotBeNull("initial position".into())),
419        Some(Some(o)) => usize::try_from(o).unwrap_or(0).saturating_sub(1),
420    };
421
422    let Some(r) = array.elements().iter().skip(skip).position(|d| d == search) else {
423        return Ok(None);
424    };
425
426    // Adjust count for the amount we skipped, plus 1 for adjusting to PG indexing scheme.
427    let p = i32::try_from(r + skip + 1).expect("fewer than i32::MAX elements in array");
428    Ok(Some(p))
429}
430
431#[derive(
432    Ord,
433    PartialOrd,
434    Clone,
435    Debug,
436    Eq,
437    PartialEq,
438    Serialize,
439    Deserialize,
440    Hash,
441    MzReflect
442)]
443pub struct ArrayToString {
444    pub elem_type: SqlScalarType,
445}
446
447#[sqlfunc]
448fn array_to_string<'a>(
449    &self,
450    array: Array<'a>,
451    delimiter: &str,
452    null_str_arg: OptionalArg<Option<&str>>,
453) -> Result<String, EvalError> {
454    // `flatten` treats absent arguments (`None`) the same as explicit NULL
455    // (`Some(None)`), both becoming `None`.
456    let null_str = null_str_arg.flatten();
457    let mut out = String::new();
458    for elem in array.elements().iter() {
459        if elem.is_null() {
460            if let Some(null_str) = null_str {
461                out.push_str(null_str);
462                out.push_str(delimiter);
463            }
464        } else {
465            stringify_datum(&mut out, elem, &self.elem_type)?;
466            out.push_str(delimiter);
467        }
468    }
469    if out.len() > 0 {
470        // Lop off last delimiter only if string is not empty
471        out.truncate(out.len() - delimiter.len());
472    }
473    Ok(out)
474}
475
476#[derive(
477    Ord,
478    PartialOrd,
479    Clone,
480    Debug,
481    Eq,
482    PartialEq,
483    Serialize,
484    Deserialize,
485    Hash,
486    MzReflect
487)]
488pub struct Coalesce;
489
490impl fmt::Display for Coalesce {
491    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
492        f.write_str("coalesce")
493    }
494}
495
496impl LazyVariadicFunc for Coalesce {
497    fn eval<'a>(
498        &'a self,
499        datums: &[Datum<'a>],
500        temp_storage: &'a RowArena,
501        exprs: &'a [MirScalarExpr],
502    ) -> Result<Datum<'a>, EvalError> {
503        for e in exprs {
504            let d = e.eval(datums, temp_storage)?;
505            if !d.is_null() {
506                return Ok(d);
507            }
508        }
509        Ok(Datum::Null)
510    }
511
512    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
513        // Note that the parser doesn't allow empty argument lists for variadic functions
514        // that use the standard function call syntax (ArrayCreate and co. are different
515        // because of the special syntax for calling them).
516        let nullable = input_types.iter().all(|typ| typ.nullable);
517        SqlColumnType::union_many(input_types).nullable(nullable)
518    }
519
520    fn propagates_nulls(&self) -> bool {
521        false
522    }
523
524    fn introduces_nulls(&self) -> bool {
525        false
526    }
527
528    fn could_error(&self) -> bool {
529        false
530    }
531
532    fn is_monotone(&self) -> bool {
533        true
534    }
535
536    fn is_associative(&self) -> bool {
537        true
538    }
539}
540
541#[derive(
542    Ord,
543    PartialOrd,
544    Clone,
545    Debug,
546    Eq,
547    PartialEq,
548    Serialize,
549    Deserialize,
550    Hash,
551    MzReflect
552)]
553pub struct RangeCreate {
554    pub elem_type: SqlScalarType,
555}
556
557impl fmt::Display for RangeCreate {
558    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
559        f.write_str(match &self.elem_type {
560            SqlScalarType::Int32 => "int4range",
561            SqlScalarType::Int64 => "int8range",
562            SqlScalarType::Date => "daterange",
563            SqlScalarType::Numeric { .. } => "numrange",
564            SqlScalarType::Timestamp { .. } => "tsrange",
565            SqlScalarType::TimestampTz { .. } => "tstzrange",
566            _ => unreachable!(),
567        })
568    }
569}
570
571impl EagerVariadicFunc for RangeCreate {
572    type Input<'a> = (Datum<'a>, Datum<'a>, Datum<'a>);
573    type Output<'a> = Result<Datum<'a>, EvalError>;
574
575    fn call<'a>(
576        &self,
577        (lower, upper, flags_datum): Self::Input<'a>,
578        temp_storage: &'a RowArena,
579    ) -> Self::Output<'a> {
580        let flags = match flags_datum {
581            Datum::Null => {
582                return Err(EvalError::InvalidRange(
583                    InvalidRangeError::NullRangeBoundFlags,
584                ));
585            }
586            o => o.unwrap_str(),
587        };
588
589        let (lower_inclusive, upper_inclusive) = parse_range_bound_flags(flags)?;
590
591        let mut range = Range::new(Some((
592            RangeBound::new(lower, lower_inclusive),
593            RangeBound::new(upper, upper_inclusive),
594        )));
595
596        range.canonicalize()?;
597
598        Ok(temp_storage.make_datum(|row| {
599            row.push_range(range).expect("errors already handled");
600        }))
601    }
602
603    fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
604        SqlScalarType::Range {
605            element_type: Box::new(self.elem_type.clone()),
606        }
607        .nullable(false)
608    }
609
610    fn introduces_nulls(&self) -> bool {
611        false
612    }
613}
614
615#[sqlfunc(sqlname = "datediff")]
616fn date_diff_date(unit_str: &str, a: Date, b: Date) -> Result<i64, EvalError> {
617    let unit = unit_str
618        .parse()
619        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
620
621    // Convert the Date into a timestamp so we can calculate age.
622    let a_ts = CheckedTimestamp::try_from(NaiveDate::from(a).and_hms_opt(0, 0, 0).unwrap())?;
623    let b_ts = CheckedTimestamp::try_from(NaiveDate::from(b).and_hms_opt(0, 0, 0).unwrap())?;
624    let diff = b_ts.diff_as(&a_ts, unit)?;
625    Ok(diff)
626}
627
628#[sqlfunc(sqlname = "datediff")]
629fn date_diff_time(unit_str: &str, a: NaiveTime, b: NaiveTime) -> Result<i64, EvalError> {
630    let unit = unit_str
631        .parse()
632        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
633
634    // Convert the Time into a timestamp so we can calculate age.
635    let a_ts =
636        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(a))?;
637    let b_ts =
638        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(b))?;
639    let diff = b_ts.diff_as(&a_ts, unit)?;
640    Ok(diff)
641}
642
643#[sqlfunc(sqlname = "datediff")]
644fn date_diff_timestamp(
645    unit: &str,
646    a: CheckedTimestamp<NaiveDateTime>,
647    b: CheckedTimestamp<NaiveDateTime>,
648) -> Result<i64, EvalError> {
649    let unit = unit
650        .parse()
651        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
652
653    let diff = b.diff_as(&a, unit)?;
654    Ok(diff)
655}
656
657#[sqlfunc(sqlname = "datediff")]
658fn date_diff_timestamp_tz(
659    unit: &str,
660    a: CheckedTimestamp<DateTime<Utc>>,
661    b: CheckedTimestamp<DateTime<Utc>>,
662) -> Result<i64, EvalError> {
663    let unit = unit
664        .parse()
665        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
666
667    let diff = b.diff_as(&a, unit)?;
668    Ok(diff)
669}
670
671#[derive(
672    Ord,
673    PartialOrd,
674    Clone,
675    Debug,
676    Eq,
677    PartialEq,
678    Serialize,
679    Deserialize,
680    Hash,
681    MzReflect
682)]
683pub struct ErrorIfNull;
684
685impl fmt::Display for ErrorIfNull {
686    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
687        f.write_str("error_if_null")
688    }
689}
690
691impl LazyVariadicFunc for ErrorIfNull {
692    fn eval<'a>(
693        &'a self,
694        datums: &[Datum<'a>],
695        temp_storage: &'a RowArena,
696        exprs: &'a [MirScalarExpr],
697    ) -> Result<Datum<'a>, EvalError> {
698        let first = exprs[0].eval(datums, temp_storage)?;
699        match first {
700            Datum::Null => {
701                let err_msg = match exprs[1].eval(datums, temp_storage)? {
702                    Datum::Null => {
703                        return Err(EvalError::Internal(
704                            "unexpected NULL in error side of error_if_null".into(),
705                        ));
706                    }
707                    o => o.unwrap_str(),
708                };
709                Err(EvalError::IfNullError(err_msg.into()))
710            }
711            _ => Ok(first),
712        }
713    }
714
715    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
716        input_types[0].scalar_type.clone().nullable(false)
717    }
718
719    fn propagates_nulls(&self) -> bool {
720        false
721    }
722
723    fn introduces_nulls(&self) -> bool {
724        false
725    }
726}
727
728#[derive(
729    Ord,
730    PartialOrd,
731    Clone,
732    Debug,
733    Eq,
734    PartialEq,
735    Serialize,
736    Deserialize,
737    Hash,
738    MzReflect
739)]
740pub struct Greatest;
741
742impl fmt::Display for Greatest {
743    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
744        f.write_str("greatest")
745    }
746}
747
748impl LazyVariadicFunc for Greatest {
749    fn eval<'a>(
750        &'a self,
751        datums: &[Datum<'a>],
752        temp_storage: &'a RowArena,
753        exprs: &'a [MirScalarExpr],
754    ) -> Result<Datum<'a>, EvalError> {
755        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
756        Ok(datums
757            .filter(|d| Ok(!d.is_null()))
758            .max()?
759            .unwrap_or(Datum::Null))
760    }
761
762    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
763        SqlColumnType::union_many(input_types)
764    }
765
766    fn propagates_nulls(&self) -> bool {
767        false
768    }
769
770    fn introduces_nulls(&self) -> bool {
771        false
772    }
773
774    fn could_error(&self) -> bool {
775        false
776    }
777
778    fn is_monotone(&self) -> bool {
779        true
780    }
781
782    fn is_associative(&self) -> bool {
783        true
784    }
785}
786
787#[sqlfunc(sqlname = "hmac")]
788fn hmac_string(to_digest: &str, key: &str, typ: &str) -> Result<Vec<u8>, EvalError> {
789    let to_digest = to_digest.as_bytes();
790    let key = key.as_bytes();
791    hmac_inner(to_digest, key, typ)
792}
793
794#[sqlfunc(sqlname = "hmac")]
795fn hmac_bytes(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
796    hmac_inner(to_digest, key, typ)
797}
798
799pub fn hmac_inner(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
800    match typ {
801        "md5" => {
802            let mut mac = Hmac::<Md5>::new_from_slice(key).expect("HMAC accepts any key size");
803            mac.update(to_digest);
804            Ok(mac.finalize().into_bytes().to_vec())
805        }
806        "sha1" => {
807            let mut mac = Hmac::<Sha1>::new_from_slice(key).expect("HMAC accepts any key size");
808            mac.update(to_digest);
809            Ok(mac.finalize().into_bytes().to_vec())
810        }
811        "sha224" => {
812            let mut mac = Hmac::<Sha224>::new_from_slice(key).expect("HMAC accepts any key size");
813            mac.update(to_digest);
814            Ok(mac.finalize().into_bytes().to_vec())
815        }
816        "sha256" => {
817            let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key size");
818            mac.update(to_digest);
819            Ok(mac.finalize().into_bytes().to_vec())
820        }
821        "sha384" => {
822            let mut mac = Hmac::<Sha384>::new_from_slice(key).expect("HMAC accepts any key size");
823            mac.update(to_digest);
824            Ok(mac.finalize().into_bytes().to_vec())
825        }
826        "sha512" => {
827            let mut mac = Hmac::<Sha512>::new_from_slice(key).expect("HMAC accepts any key size");
828            mac.update(to_digest);
829            Ok(mac.finalize().into_bytes().to_vec())
830        }
831        other => Err(EvalError::InvalidHashAlgorithm(other.into())),
832    }
833}
834
835#[sqlfunc]
836fn jsonb_build_array<'a>(datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> JsonbRef<'a> {
837    let datum = temp_storage.make_datum(|packer| {
838        packer.push_list(datums.into_iter().map(|d| match d {
839            Datum::Null => Datum::JsonNull,
840            d => d,
841        }))
842    });
843    JsonbRef::from_datum(datum)
844}
845
846#[sqlfunc]
847fn jsonb_build_object<'a>(
848    mut kvs: Variadic<(Datum<'a>, Datum<'a>)>,
849    temp_storage: &'a RowArena,
850) -> Result<JsonbRef<'a>, EvalError> {
851    kvs.0.sort_by(|kv1, kv2| kv1.0.cmp(&kv2.0));
852    kvs.0.dedup_by(|kv1, kv2| kv1.0 == kv2.0);
853    let datum = temp_storage.try_make_datum(|packer| {
854        packer.push_dict_with(|packer| {
855            for (k, v) in kvs {
856                if k.is_null() {
857                    return Err(EvalError::KeyCannotBeNull);
858                }
859                let v = match v {
860                    Datum::Null => Datum::JsonNull,
861                    d => d,
862                };
863                packer.push(k);
864                packer.push(v);
865            }
866            Ok(())
867        })
868    })?;
869    Ok(JsonbRef::from_datum(datum))
870}
871
872#[derive(
873    Ord,
874    PartialOrd,
875    Clone,
876    Debug,
877    Eq,
878    PartialEq,
879    Serialize,
880    Deserialize,
881    Hash,
882    MzReflect
883)]
884pub struct Least;
885
886impl fmt::Display for Least {
887    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
888        f.write_str("least")
889    }
890}
891
892impl LazyVariadicFunc for Least {
893    fn eval<'a>(
894        &'a self,
895        datums: &[Datum<'a>],
896        temp_storage: &'a RowArena,
897        exprs: &'a [MirScalarExpr],
898    ) -> Result<Datum<'a>, EvalError> {
899        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
900        Ok(datums
901            .filter(|d| Ok(!d.is_null()))
902            .min()?
903            .unwrap_or(Datum::Null))
904    }
905
906    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
907        SqlColumnType::union_many(input_types)
908    }
909
910    fn propagates_nulls(&self) -> bool {
911        false
912    }
913
914    fn introduces_nulls(&self) -> bool {
915        false
916    }
917
918    fn could_error(&self) -> bool {
919        false
920    }
921
922    fn is_monotone(&self) -> bool {
923        true
924    }
925
926    fn is_associative(&self) -> bool {
927        true
928    }
929}
930
931#[derive(
932    Ord,
933    PartialOrd,
934    Clone,
935    Debug,
936    Eq,
937    PartialEq,
938    Serialize,
939    Deserialize,
940    Hash,
941    MzReflect
942)]
943pub struct ListCreate {
944    pub elem_type: SqlScalarType,
945}
946
947#[sqlfunc(
948    output_type_expr = "SqlScalarType::List { element_type: Box::new(self.elem_type.clone()), custom_id: None }.nullable(false)",
949    introduces_nulls = false
950)]
951fn list_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
952    temp_storage.make_datum(|packer| packer.push_list(datums))
953}
954
955#[derive(
956    Ord,
957    PartialOrd,
958    Clone,
959    Debug,
960    Eq,
961    PartialEq,
962    Serialize,
963    Deserialize,
964    Hash,
965    MzReflect
966)]
967pub struct RecordCreate {
968    pub field_names: Vec<ColumnName>,
969}
970
971#[sqlfunc(
972    output_type_expr = "SqlScalarType::Record { fields: self.field_names.clone().into_iter().zip_eq(input_types.iter().cloned()).collect(), custom_id: None }.nullable(false)",
973    introduces_nulls = false
974)]
975fn record_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
976    temp_storage.make_datum(|packer| packer.push_list(datums.iter().copied()))
977}
978
979#[sqlfunc(
980    output_type_expr = "input_types[0].scalar_type.unwrap_list_nth_layer_type(input_types.len() - 1).clone().nullable(true)",
981    introduces_nulls = true
982)]
983// TODO(benesch): remove potentially dangerous usage of `as`.
984#[allow(clippy::as_conversions)]
985fn list_index<'a>(buf: DatumList<'a>, indices: Variadic<i64>) -> Datum<'a> {
986    let mut buf = Datum::List(buf);
987    for i in indices {
988        if buf.is_null() {
989            break;
990        }
991        if i < 1 {
992            return Datum::Null;
993        }
994
995        buf = match buf.unwrap_list().iter().nth(i as usize - 1) {
996            Some(datum) => datum,
997            None => return Datum::Null,
998        }
999    }
1000    buf
1001}
1002
1003#[sqlfunc(sqlname = "makeaclitem")]
1004fn make_acl_item(
1005    grantee_oid: u32,
1006    grantor_oid: u32,
1007    privileges: &str,
1008    is_grantable: bool,
1009) -> Result<AclItem, EvalError> {
1010    let grantee = Oid(grantee_oid);
1011    let grantor = Oid(grantor_oid);
1012    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1013        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1014    if is_grantable {
1015        return Err(EvalError::Unsupported {
1016            feature: "GRANT OPTION".into(),
1017            discussion_no: None,
1018        });
1019    }
1020
1021    Ok(AclItem {
1022        grantee,
1023        grantor,
1024        acl_mode,
1025    })
1026}
1027
1028#[sqlfunc(sqlname = "make_mz_aclitem")]
1029fn make_mz_acl_item(
1030    grantee_str: &str,
1031    grantor_str: &str,
1032    privileges: &str,
1033) -> Result<MzAclItem, EvalError> {
1034    let grantee: RoleId = grantee_str
1035        .parse()
1036        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1037    let grantor: RoleId = grantor_str
1038        .parse()
1039        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1040    if grantor == RoleId::Public {
1041        return Err(EvalError::InvalidRoleId(
1042            "mz_aclitem grantor cannot be PUBLIC role".into(),
1043        ));
1044    }
1045    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1046        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1047
1048    Ok(MzAclItem {
1049        grantee,
1050        grantor,
1051        acl_mode,
1052    })
1053}
1054
1055#[sqlfunc(sqlname = "makets")]
1056// TODO(benesch): remove potentially dangerous usage of `as`.
1057#[allow(clippy::as_conversions)]
1058fn make_timestamp(
1059    year: i64,
1060    month: i64,
1061    day: i64,
1062    hour: i64,
1063    minute: i64,
1064    second_float: f64,
1065) -> Result<Option<CheckedTimestamp<NaiveDateTime>>, EvalError> {
1066    let year: i32 = match year.try_into() {
1067        Ok(year) => year,
1068        Err(_) => return Ok(None),
1069    };
1070    let month: u32 = match month.try_into() {
1071        Ok(month) => month,
1072        Err(_) => return Ok(None),
1073    };
1074    let day: u32 = match day.try_into() {
1075        Ok(day) => day,
1076        Err(_) => return Ok(None),
1077    };
1078    let hour: u32 = match hour.try_into() {
1079        Ok(hour) => hour,
1080        Err(_) => return Ok(None),
1081    };
1082    let minute: u32 = match minute.try_into() {
1083        Ok(minute) => minute,
1084        Err(_) => return Ok(None),
1085    };
1086    let second = second_float as u32;
1087    let micros = ((second_float - second as f64) * 1_000_000.0) as u32;
1088    let date = match NaiveDate::from_ymd_opt(year, month, day) {
1089        Some(date) => date,
1090        None => return Ok(None),
1091    };
1092    let timestamp = match date.and_hms_micro_opt(hour, minute, second, micros) {
1093        Some(timestamp) => timestamp,
1094        None => return Ok(None),
1095    };
1096    Ok(Some(timestamp.try_into()?))
1097}
1098
1099#[derive(
1100    Ord,
1101    PartialOrd,
1102    Clone,
1103    Debug,
1104    Eq,
1105    PartialEq,
1106    Serialize,
1107    Deserialize,
1108    Hash,
1109    MzReflect
1110)]
1111pub struct MapBuild {
1112    pub value_type: SqlScalarType,
1113}
1114
1115#[sqlfunc(
1116    output_type_expr = "SqlScalarType::Map { value_type: Box::new(self.value_type.clone()), custom_id: None }.nullable(false)",
1117    introduces_nulls = false
1118)]
1119fn map_build<'a>(
1120    &self,
1121    datums: Variadic<(Option<&str>, Datum<'a>)>,
1122    temp_storage: &'a RowArena,
1123) -> Datum<'a> {
1124    // Collect into a `BTreeMap` to provide the same semantics as it.
1125    let map: std::collections::BTreeMap<&str, _> = datums
1126        .into_iter()
1127        .filter_map(|(k, v)| k.map(|k| (k, v)))
1128        .collect();
1129
1130    temp_storage.make_datum(|packer| packer.push_dict(map))
1131}
1132
1133#[derive(
1134    Ord,
1135    PartialOrd,
1136    Clone,
1137    Debug,
1138    Eq,
1139    PartialEq,
1140    Serialize,
1141    Deserialize,
1142    Hash,
1143    MzReflect
1144)]
1145pub struct Or;
1146
1147impl fmt::Display for Or {
1148    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1149        f.write_str("OR")
1150    }
1151}
1152
1153impl LazyVariadicFunc for Or {
1154    fn eval<'a>(
1155        &'a self,
1156        datums: &[Datum<'a>],
1157        temp_storage: &'a RowArena,
1158        exprs: &'a [MirScalarExpr],
1159    ) -> Result<Datum<'a>, EvalError> {
1160        // If any is true, then return true. Else, if any is null, then return null. Else, return false.
1161        let mut null = false;
1162        let mut err = None;
1163        for expr in exprs {
1164            match expr.eval(datums, temp_storage) {
1165                Ok(Datum::False) => {}
1166                Ok(Datum::True) => return Ok(Datum::True), // short-circuit
1167                // No return in these two cases, because we might still see a true
1168                Ok(Datum::Null) => null = true,
1169                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
1170                _ => unreachable!(),
1171            }
1172        }
1173        match (err, null) {
1174            (Some(err), _) => Err(err),
1175            (None, true) => Ok(Datum::Null),
1176            (None, false) => Ok(Datum::False),
1177        }
1178    }
1179
1180    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1181        let in_nullable = input_types.iter().any(|t| t.nullable);
1182        SqlScalarType::Bool.nullable(in_nullable)
1183    }
1184
1185    fn propagates_nulls(&self) -> bool {
1186        false
1187    }
1188
1189    fn introduces_nulls(&self) -> bool {
1190        false
1191    }
1192
1193    fn could_error(&self) -> bool {
1194        false
1195    }
1196
1197    fn is_monotone(&self) -> bool {
1198        true
1199    }
1200
1201    fn is_associative(&self) -> bool {
1202        true
1203    }
1204
1205    fn is_infix_op(&self) -> bool {
1206        true
1207    }
1208}
1209
1210#[sqlfunc(sqlname = "lpad")]
1211fn pad_leading(string: &str, raw_len: i32, pad: OptionalArg<&str>) -> Result<String, EvalError> {
1212    let len = match usize::try_from(raw_len) {
1213        Ok(len) => len,
1214        Err(_) => {
1215            return Err(EvalError::InvalidParameterValue(
1216                "length must be nonnegative".into(),
1217            ));
1218        }
1219    };
1220    if len > MAX_STRING_FUNC_RESULT_BYTES {
1221        return Err(EvalError::LengthTooLarge);
1222    }
1223
1224    let pad_string = pad.unwrap_or(" ");
1225
1226    let (end_char, end_char_byte_offset) = string
1227        .chars()
1228        .take(len)
1229        .fold((0, 0), |acc, char| (acc.0 + 1, acc.1 + char.len_utf8()));
1230
1231    let mut buf = String::with_capacity(len);
1232    if len == end_char {
1233        buf.push_str(&string[0..end_char_byte_offset]);
1234    } else {
1235        buf.extend(pad_string.chars().cycle().take(len - end_char));
1236        buf.push_str(string);
1237    }
1238
1239    Ok(buf)
1240}
1241
1242#[sqlfunc(
1243    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true)",
1244    introduces_nulls = true
1245)]
1246fn regexp_match<'a>(
1247    haystack: &'a str,
1248    needle: &str,
1249    flags: OptionalArg<&str>,
1250    temp_storage: &'a RowArena,
1251) -> Result<Datum<'a>, EvalError> {
1252    let flags = flags.unwrap_or("");
1253    let needle = build_regex(needle, flags)?;
1254    regexp_match_static(Datum::String(haystack), temp_storage, &needle)
1255}
1256
1257#[sqlfunc(
1258    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1259    introduces_nulls = false
1260)]
1261fn regexp_split_to_array<'a>(
1262    text: &str,
1263    regexp_str: &str,
1264    flags: OptionalArg<&str>,
1265    temp_storage: &'a RowArena,
1266) -> Result<Datum<'a>, EvalError> {
1267    let flags = flags.unwrap_or("");
1268    let regexp = build_regex(regexp_str, flags)?;
1269    regexp_split_to_array_re(text, &regexp, temp_storage)
1270}
1271
1272#[sqlfunc]
1273fn regexp_replace<'a>(
1274    source: &'a str,
1275    pattern: &str,
1276    replacement: &str,
1277    flags_opt: OptionalArg<&str>,
1278) -> Result<Cow<'a, str>, EvalError> {
1279    let flags = flags_opt.0.unwrap_or("");
1280    let (limit, flags) = regexp_replace_parse_flags(flags);
1281    let regexp = build_regex(pattern, &flags)?;
1282    Ok(regexp.replacen(source, limit, replacement))
1283}
1284
1285#[sqlfunc]
1286fn replace(text: &str, from: &str, to: &str) -> Result<String, EvalError> {
1287    // As a compromise to avoid always nearly duplicating the work of replace by doing size estimation,
1288    // we first check if it's possible for the fully replaced string to exceed the limit by assuming that
1289    // every possible substring is replaced.
1290    //
1291    // If that estimate exceeds the limit, we then do a more precise (and expensive) estimate by counting
1292    // the actual number of replacements that would occur, and using that to calculate the final size.
1293    let possible_size = text.len() * to.len();
1294    if possible_size > MAX_STRING_FUNC_RESULT_BYTES {
1295        let replacement_count = text.matches(from).count();
1296        let estimated_size = text.len() + replacement_count * (to.len().saturating_sub(from.len()));
1297        if estimated_size > MAX_STRING_FUNC_RESULT_BYTES {
1298            return Err(EvalError::LengthTooLarge);
1299        }
1300    }
1301
1302    Ok(text.replace(from, to))
1303}
1304
1305#[sqlfunc(
1306    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1307    introduces_nulls = false,
1308    propagates_nulls = false
1309)]
1310fn string_to_array<'a>(
1311    string: &'a str,
1312    delimiter: Option<&'a str>,
1313    null_string: OptionalArg<Option<&'a str>>,
1314    temp_storage: &'a RowArena,
1315) -> Result<Datum<'a>, EvalError> {
1316    if string.is_empty() {
1317        let mut row = Row::default();
1318        let mut packer = row.packer();
1319        packer.try_push_array(&[], std::iter::empty::<Datum>())?;
1320
1321        return Ok(temp_storage.push_unary_row(row));
1322    }
1323
1324    let Some(delimiter) = delimiter else {
1325        let split_all_chars_delimiter = "";
1326        return string_to_array_impl(
1327            string,
1328            split_all_chars_delimiter,
1329            null_string.flatten(),
1330            temp_storage,
1331        );
1332    };
1333
1334    if delimiter.is_empty() {
1335        let mut row = Row::default();
1336        let mut packer = row.packer();
1337        packer.try_push_array(
1338            &[ArrayDimension {
1339                lower_bound: 1,
1340                length: 1,
1341            }],
1342            vec![string].into_iter().map(Datum::String),
1343        )?;
1344
1345        Ok(temp_storage.push_unary_row(row))
1346    } else {
1347        string_to_array_impl(string, delimiter, null_string.flatten(), temp_storage)
1348    }
1349}
1350
1351fn string_to_array_impl<'a>(
1352    string: &str,
1353    delimiter: &str,
1354    null_string: Option<&'a str>,
1355    temp_storage: &'a RowArena,
1356) -> Result<Datum<'a>, EvalError> {
1357    let mut row = Row::default();
1358    let mut packer = row.packer();
1359
1360    let result = string.split(delimiter);
1361    let found: Vec<&str> = if delimiter.is_empty() {
1362        result.filter(|s| !s.is_empty()).collect()
1363    } else {
1364        result.collect()
1365    };
1366    let array_dimensions = [ArrayDimension {
1367        lower_bound: 1,
1368        length: found.len(),
1369    }];
1370
1371    if let Some(null_string) = null_string {
1372        let found_datums = found.into_iter().map(|chunk| {
1373            if chunk.eq(null_string) {
1374                Datum::Null
1375            } else {
1376                Datum::String(chunk)
1377            }
1378        });
1379
1380        packer.try_push_array(&array_dimensions, found_datums)?;
1381    } else {
1382        packer.try_push_array(&array_dimensions, found.into_iter().map(Datum::String))?;
1383    }
1384
1385    Ok(temp_storage.push_unary_row(row))
1386}
1387
// Returns the substring of `s` starting at 1-based character position `start` and spanning
// `length` characters, or running to the end of `s` when `length` is omitted. Positions and
// lengths are in characters, not bytes. Positions before 1 still count toward `length`, so
// e.g. start = 0, length = 3 yields only the first 2 characters. Errors on a negative
// `length`.
#[sqlfunc]
fn substr<'a>(s: &'a str, start: i32, length: OptionalArg<i32>) -> Result<&'a str, EvalError> {
    // 0-based start position; may be negative when `start` < 1.
    let raw_start_idx = i64::from(start) - 1;
    // Clamp to the beginning of the string before converting to an index.
    let start_idx = match usize::try_from(cmp::max(raw_start_idx, 0)) {
        Ok(i) => i,
        Err(_) => {
            return Err(EvalError::InvalidParameterValue(
                format!(
                    "substring starting index ({}) exceeds min/max position",
                    raw_start_idx
                )
                .into(),
            ));
        }
    };

    let mut char_indices = s.char_indices();
    // Extracts the byte offset from a `char_indices` item.
    let get_str_index = |(index, _char)| index;

    let str_len = s.len();
    // Byte offset of the start character (or the end of `s` if `start` is past it). This
    // advances `char_indices` past the start character.
    let start_char_idx = char_indices.nth(start_idx).map_or(str_len, get_str_index);

    if let OptionalArg(Some(len)) = length {
        let end_idx = match i64::from(len) {
            e if e < 0 => {
                return Err(EvalError::InvalidParameterValue(
                    "negative substring length not allowed".into(),
                ));
            }
            // Zero length, or a range that ends before position 1: nothing to return.
            e if e == 0 || e + raw_start_idx < 1 => return Ok(""),
            e => {
                // Characters to skip *after* the start character so that the next `nth`
                // lands on the first character past the substring. When the start was
                // clamped (raw_start_idx < 0), the positions before the string are deducted
                // from the length.
                let e = cmp::min(raw_start_idx + e - 1, e - 1);
                match usize::try_from(e) {
                    Ok(i) => i,
                    Err(_) => {
                        return Err(EvalError::InvalidParameterValue(
                            format!("substring length ({}) exceeds max position", e).into(),
                        ));
                    }
                }
            }
        };

        // Byte offset just past the substring (or the end of `s`). The iterator continues
        // from where the start lookup left it.
        let end_char_idx = char_indices.nth(end_idx).map_or(str_len, get_str_index);

        Ok(&s[start_char_idx..end_char_idx])
    } else {
        Ok(&s[start_char_idx..])
    }
}
1438
1439#[sqlfunc(sqlname = "split_string")]
1440fn split_part<'a>(string: &'a str, delimiter: &str, field: i32) -> Result<&'a str, EvalError> {
1441    let index = match usize::try_from(i64::from(field) - 1) {
1442        Ok(index) => index,
1443        Err(_) => {
1444            return Err(EvalError::InvalidParameterValue(
1445                "field position must be greater than zero".into(),
1446            ));
1447        }
1448    };
1449
1450    // If the provided delimiter is the empty string,
1451    // PostgreSQL does not break the string into individual
1452    // characters. Instead, it generates the following parts: [string].
1453    if delimiter.is_empty() {
1454        if index == 0 {
1455            return Ok(string);
1456        } else {
1457            return Ok("");
1458        }
1459    }
1460
1461    // If provided index is greater than the number of split parts,
1462    // return an empty string.
1463    Ok(string.split(delimiter).nth(index).unwrap_or(""))
1464}
1465
1466#[sqlfunc(is_associative = true)]
1467fn concat(strs: Variadic<Option<&str>>) -> Result<String, EvalError> {
1468    let mut total_size = 0;
1469    for s in &strs {
1470        if let Some(s) = s {
1471            total_size += s.len();
1472            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1473                return Err(EvalError::LengthTooLarge);
1474            }
1475        }
1476    }
1477    let mut buf = String::with_capacity(total_size);
1478    for s in strs {
1479        if let Some(s) = s {
1480            buf.push_str(s);
1481        }
1482    }
1483    Ok(buf)
1484}
1485
1486#[sqlfunc]
1487fn concat_ws(ws: &str, rest: Variadic<Option<&str>>) -> Result<String, EvalError> {
1488    let mut total_size = 0;
1489    for s in &rest {
1490        if let Some(s) = s {
1491            total_size += s.len();
1492            total_size += ws.len();
1493            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1494                return Err(EvalError::LengthTooLarge);
1495            }
1496        }
1497    }
1498
1499    let buf = Itertools::join(&mut rest.into_iter().filter_map(|s| s), ws);
1500
1501    Ok(buf)
1502}
1503
1504#[sqlfunc]
1505fn translate(string: &str, from_str: &str, to_str: &str) -> String {
1506    let from = from_str.chars().collect::<Vec<_>>();
1507    let to = to_str.chars().collect::<Vec<_>>();
1508
1509    string
1510        .chars()
1511        .filter_map(|c| match from.iter().position(|f| f == &c) {
1512            Some(idx) => to.get(idx).copied(),
1513            None => Some(c),
1514        })
1515        .collect()
1516}
1517
1518#[sqlfunc(
1519    output_type_expr = "input_types[0].scalar_type.clone().nullable(false)",
1520    introduces_nulls = false
1521)]
1522// TODO(benesch): remove potentially dangerous usage of `as`.
1523#[allow(clippy::as_conversions)]
1524fn list_slice_linear<'a>(
1525    list: DatumList<'a>,
1526    first: (i64, i64),
1527    remainder: Variadic<(i64, i64)>,
1528    temp_storage: &'a RowArena,
1529) -> Datum<'a> {
1530    let mut start_idx = 0;
1531    let mut total_length = usize::MAX;
1532
1533    for (start, end) in std::iter::once(first).chain(remainder) {
1534        let start = std::cmp::max(start, 1);
1535
1536        // Result should be empty list.
1537        if start > end {
1538            start_idx = 0;
1539            total_length = 0;
1540            break;
1541        }
1542
1543        let start_inner = start as usize - 1;
1544        // Start index only moves to geq positions.
1545        start_idx += start_inner;
1546
1547        // Length index only moves to leq positions
1548        let length_inner = (end - start) as usize + 1;
1549        total_length = std::cmp::min(length_inner, total_length - start_inner);
1550    }
1551
1552    let iter = list.iter().skip(start_idx).take(total_length);
1553
1554    temp_storage.make_datum(|row| {
1555        row.push_list_with(|row| {
1556            // if iter is empty, will get the appropriate empty list.
1557            for d in iter {
1558                row.push(d);
1559            }
1560        });
1561    })
1562}
1563
1564#[sqlfunc(sqlname = "timestamp_bin")]
1565fn date_bin_timestamp(
1566    stride: Interval,
1567    source: CheckedTimestamp<NaiveDateTime>,
1568    origin: CheckedTimestamp<NaiveDateTime>,
1569) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
1570    date_bin(stride, source, origin)
1571}
1572
1573#[sqlfunc(sqlname = "timestamptz_bin")]
1574fn date_bin_timestamp_tz(
1575    stride: Interval,
1576    source: CheckedTimestamp<DateTime<Utc>>,
1577    origin: CheckedTimestamp<DateTime<Utc>>,
1578) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
1579    date_bin(stride, source, origin)
1580}
1581
1582#[sqlfunc(sqlname = "timezonet")]
1583fn timezone_time_variadic(
1584    tz_str: &str,
1585    time: NaiveTime,
1586    wall_time: CheckedTimestamp<DateTime<Utc>>,
1587) -> Result<NaiveTime, EvalError> {
1588    parse_timezone(tz_str, TimezoneSpec::Posix)
1589        .map(|tz| timezone_time(tz, time, &wall_time.naive_utc()))
1590}
/// A variadic scalar function that receives its argument *expressions* rather than
/// pre-evaluated argument values.
///
/// Implementors decide whether and when to evaluate each expression in `exprs` against the
/// input `datums`; `Or`, for example, stops at the first `TRUE`. Functions that want all
/// arguments evaluated up front implement `EagerVariadicFunc` instead and obtain this trait
/// through the blanket impl for eager functions.
pub(crate) trait LazyVariadicFunc: fmt::Display {
    /// Evaluates the function, evaluating argument expressions from `exprs` as needed.
    /// `temp_storage` backs any datums the result borrows from.
    fn eval<'a>(
        &'a self,
        datums: &[Datum<'a>],
        temp_storage: &'a RowArena,
        exprs: &'a [MirScalarExpr],
    ) -> Result<Datum<'a>, EvalError>;

    /// The output SqlColumnType of this function, given the argument column types.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input.
    fn propagates_nulls(&self) -> bool;

    /// Whether this function will produce NULL on non-NULL input.
    fn introduces_nulls(&self) -> bool;

    /// Whether this function might error on non-error input. Conservatively defaults to
    /// `true`.
    fn could_error(&self) -> bool {
        true
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1628
/// A variadic scalar function whose arguments are all evaluated before it runs.
///
/// The blanket `LazyVariadicFunc` impl evaluates the argument expressions, decodes them
/// into `Self::Input`, and then calls `call`. NULL propagation and error forwarding happen
/// during that decoding step.
pub(crate) trait EagerVariadicFunc: fmt::Display {
    /// The decoded argument list. `Input::nullable()` determines the default of
    /// `propagates_nulls`.
    type Input<'a>: InputDatumType<'a, EvalError>;
    /// The result type. `Output::nullable()` and `Output::fallible()` determine the
    /// defaults of `introduces_nulls` and `could_error`.
    type Output<'a>: OutputDatumType<'a, EvalError>;

    /// Computes the result from already-evaluated, decoded arguments.
    fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;

    /// The output SqlColumnType of this function, given the argument column types.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input. Defaults to `true` exactly
    /// when `Self::Input::nullable()` is false.
    fn propagates_nulls(&self) -> bool {
        !Self::Input::nullable()
    }

    /// Whether this function will produce NULL on non-NULL input. Defaults to
    /// `Self::Output::nullable()`.
    fn introduces_nulls(&self) -> bool {
        Self::Output::nullable()
    }

    /// Whether this function might error on non-error input. Defaults to
    /// `Self::Output::fallible()`.
    fn could_error(&self) -> bool {
        Self::Output::fallible()
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1661
1662/// Blanket `LazyVariadicFunc` impl for each eager type, bridging
1663/// expression evaluation and null propagation via `InputDatumType::try_from_iter`.
1664impl<T: EagerVariadicFunc> LazyVariadicFunc for T {
1665    fn eval<'a>(
1666        &'a self,
1667        datums: &[Datum<'a>],
1668        temp_storage: &'a RowArena,
1669        exprs: &'a [MirScalarExpr],
1670    ) -> Result<Datum<'a>, EvalError> {
1671        let mut datums = exprs.iter().map(|e| e.eval(datums, temp_storage));
1672        match T::Input::try_from_iter(&mut datums) {
1673            Ok(input) => self.call(input, temp_storage).into_result(temp_storage),
1674            Err(Ok(None)) => Err(EvalError::Internal("missing parameter".into())),
1675            Err(Ok(Some(datum))) if datum.is_null() => Ok(datum),
1676            Err(Ok(Some(_datum))) => {
1677                // datum is _not_ NULL
1678                Err(EvalError::Internal("invalid input type".into()))
1679            }
1680            Err(Err(res)) => Err(res),
1681        }
1682    }
1683
1684    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1685        self.output_type(input_types)
1686    }
1687
1688    fn propagates_nulls(&self) -> bool {
1689        self.propagates_nulls()
1690    }
1691
1692    fn introduces_nulls(&self) -> bool {
1693        self.introduces_nulls()
1694    }
1695
1696    fn could_error(&self) -> bool {
1697        self.could_error()
1698    }
1699
1700    fn is_monotone(&self) -> bool {
1701        self.is_monotone()
1702    }
1703
1704    fn is_associative(&self) -> bool {
1705        self.is_associative()
1706    }
1707
1708    fn is_infix_op(&self) -> bool {
1709        self.is_infix_op()
1710    }
1711}
1712
// The registry of variadic scalar functions. Each entry is `Variant(Type)`, where `Type` is
// one of the function types defined in this module: hand-written ones such as `Or` and
// `MapBuild`, and presumably the structs generated by `#[sqlfunc]` for the free functions
// (confirm against the macro). The `VariadicFunc::And(_)` / `VariadicFunc::Or(_)` matches
// below show that this produces the `VariadicFunc` enum.
// NOTE(review): do not reorder entries without checking the macro. Derived `Ord` or
// serialized representations may depend on variant order.
derive_variadic! {
    Coalesce(Coalesce),
    Greatest(Greatest),
    Least(Least),
    Concat(Concat),
    ConcatWs(ConcatWs),
    MakeTimestamp(MakeTimestamp),
    PadLeading(PadLeading),
    Substr(Substr),
    Replace(Replace),
    JsonbBuildArray(JsonbBuildArray),
    JsonbBuildObject(JsonbBuildObject),
    MapBuild(MapBuild),
    ArrayCreate(ArrayCreate),
    ArrayToString(ArrayToString),
    ArrayIndex(ArrayIndex),
    ListCreate(ListCreate),
    RecordCreate(RecordCreate),
    ListIndex(ListIndex),
    ListSliceLinear(ListSliceLinear),
    SplitPart(SplitPart),
    RegexpMatch(RegexpMatch),
    HmacString(HmacString),
    HmacBytes(HmacBytes),
    ErrorIfNull(ErrorIfNull),
    DateBinTimestamp(DateBinTimestamp),
    DateBinTimestampTz(DateBinTimestampTz),
    DateDiffTimestamp(DateDiffTimestamp),
    DateDiffTimestampTz(DateDiffTimestampTz),
    DateDiffDate(DateDiffDate),
    DateDiffTime(DateDiffTime),
    And(And),
    Or(Or),
    RangeCreate(RangeCreate),
    MakeAclItem(MakeAclItem),
    MakeMzAclItem(MakeMzAclItem),
    Translate(Translate),
    ArrayPosition(ArrayPosition),
    ArrayFill(ArrayFill),
    StringToArray(StringToArray),
    TimezoneTimeVariadic(TimezoneTimeVariadic),
    RegexpSplitToArray(RegexpSplitToArray),
    RegexpReplace(RegexpReplace),
}
1757
1758impl VariadicFunc {
1759    pub fn switch_and_or(&self) -> Self {
1760        match self {
1761            VariadicFunc::And(_) => Or.into(),
1762            VariadicFunc::Or(_) => And.into(),
1763            _ => unreachable!(),
1764        }
1765    }
1766
1767    /// Gives the unit (u) of OR or AND, such that `u AND/OR x == x`.
1768    /// Note that a 0-arg AND/OR evaluates to unit_of_and_or.
1769    pub fn unit_of_and_or(&self) -> MirScalarExpr {
1770        match self {
1771            VariadicFunc::And(_) => MirScalarExpr::literal_true(),
1772            VariadicFunc::Or(_) => MirScalarExpr::literal_false(),
1773            _ => unreachable!(),
1774        }
1775    }
1776
1777    /// Gives the zero (z) of OR or AND, such that `z AND/OR x == z`.
1778    pub fn zero_of_and_or(&self) -> MirScalarExpr {
1779        match self {
1780            VariadicFunc::And(_) => MirScalarExpr::literal_false(),
1781            VariadicFunc::Or(_) => MirScalarExpr::literal_true(),
1782            _ => unreachable!(),
1783        }
1784    }
1785}