Skip to main content

mz_expr/scalar/func/
variadic.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Variadic functions.
15
16use std::borrow::Cow;
17use std::cmp;
18use std::fmt;
19
20use aws_lc_rs::hmac as aws_hmac;
21use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
22use fallible_iterator::FallibleIterator;
23use hmac::{Hmac, Mac};
24use itertools::Itertools;
25use md5::Md5;
26use mz_expr_derive::sqlfunc;
27use mz_lowertest::MzReflect;
28use mz_ore::cast::{CastFrom, ReinterpretCast};
29use mz_pgtz::timezone::TimezoneSpec;
30use mz_repr::ReprColumnType;
31use mz_repr::adt::array::{Array, ArrayDimension, ArrayDimensions, InvalidArrayError};
32use mz_repr::adt::mz_acl_item::{AclItem, AclMode, MzAclItem};
33use mz_repr::adt::range::{InvalidRangeError, Range, RangeBound, parse_range_bound_flags};
34use mz_repr::adt::system::Oid;
35use mz_repr::adt::timestamp::CheckedTimestamp;
36use mz_repr::role_id::RoleId;
37use mz_repr::{
38    ColumnName, Datum, DatumList, FromDatum, InputDatumType, OptionalArg, OutputDatumType, Row,
39    RowArena, SqlColumnType, SqlScalarType, Variadic,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::func::{
44    CaseLiteral, MAX_STRING_FUNC_RESULT_BYTES, array_create_scalar, build_regex, date_bin,
45    parse_timezone, regexp_match_static, regexp_replace_parse_flags, regexp_split_to_array_re,
46    stringify_datum, timezone_time,
47};
48use crate::{Eval, EvalError, MirScalarExpr};
49use mz_repr::adt::date::Date;
50use mz_repr::adt::interval::Interval;
51use mz_repr::adt::jsonb::JsonbRef;
52
53#[derive(
54    Ord,
55    PartialOrd,
56    Clone,
57    Debug,
58    Eq,
59    PartialEq,
60    Serialize,
61    Deserialize,
62    Hash,
63    MzReflect
64)]
65pub struct And;
66
67impl fmt::Display for And {
68    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69        f.write_str("AND")
70    }
71}
72
73impl LazyVariadicFunc for And {
74    fn eval<'a>(
75        &'a self,
76        datums: &[Datum<'a>],
77        temp_storage: &'a RowArena,
78        exprs: &'a [impl Eval],
79    ) -> Result<Datum<'a>, EvalError> {
80        // If any is false, then return false. Else, if any is null, then return null. Else, return true.
81        let mut null = false;
82        let mut err = None;
83        for expr in exprs {
84            match expr.eval(datums, temp_storage) {
85                Ok(Datum::False) => return Ok(Datum::False), // short-circuit
86                Ok(Datum::True) => {}
87                // No return in these two cases, because we might still see a false
88                Ok(Datum::Null) => null = true,
89                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
90                _ => unreachable!(),
91            }
92        }
93        match (err, null) {
94            (Some(err), _) => Err(err),
95            (None, true) => Ok(Datum::Null),
96            (None, false) => Ok(Datum::True),
97        }
98    }
99
100    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
101        let in_nullable = input_types.iter().any(|t| t.nullable);
102        SqlScalarType::Bool.nullable(in_nullable)
103    }
104
105    fn propagates_nulls(&self) -> bool {
106        false
107    }
108
109    fn introduces_nulls(&self) -> bool {
110        false
111    }
112
113    fn could_error(&self) -> bool {
114        false
115    }
116
117    fn is_monotone(&self) -> bool {
118        true
119    }
120
121    fn is_associative(&self) -> bool {
122        true
123    }
124
125    fn is_infix_op(&self) -> bool {
126        true
127    }
128}
129
130#[derive(
131    Ord,
132    PartialOrd,
133    Clone,
134    Debug,
135    Eq,
136    PartialEq,
137    Serialize,
138    Deserialize,
139    Hash,
140    MzReflect
141)]
142pub struct ArrayCreate {
143    pub elem_type: SqlScalarType,
144}
145
146/// Constructs a new multidimensional array out of an arbitrary number of
147/// lower-dimensional arrays.
148///
149/// For example, if given three 1D arrays of length 2, this function will
150/// construct a 2D array with dimensions 3x2.
151///
152/// The input datums in `datums` must all be arrays of the same dimensions.
153/// (The arrays must also be of the same element type, but that is checked by
154/// the SQL type system, rather than checked here at runtime.)
155///
156/// If all input arrays are zero-dimensional arrays, then the output is a zero-
157/// dimensional array. Otherwise, the lower bound of the additional dimension is
158/// one and the length of the new dimension is equal to `datums.len()`.
159///
160/// Null elements are allowed and considered to be zero-dimensional arrays.
161#[sqlfunc(
162    ArrayCreate,
163    output_type_expr = "match &self.elem_type { SqlScalarType::Array(_) => self.elem_type.clone().nullable(false), _ => SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false) }",
164    introduces_nulls = false
165)]
166fn array_create<'a>(
167    &self,
168    datums: Variadic<Datum<'a>>,
169    temp_storage: &'a RowArena,
170) -> Result<Datum<'a>, EvalError> {
171    match &self.elem_type {
172        SqlScalarType::Array(_) => array_create_multidim(&datums, temp_storage),
173        _ => array_create_scalar(&datums, temp_storage),
174    }
175}
176fn array_create_multidim<'a>(
177    datums: &[Datum<'a>],
178    temp_storage: &'a RowArena,
179) -> Result<Datum<'a>, EvalError> {
180    let mut dim: Option<ArrayDimensions> = None;
181    for datum in datums {
182        let actual_dims = match datum {
183            Datum::Null => ArrayDimensions::default(),
184            Datum::Array(arr) => arr.dims(),
185            d => panic!("unexpected datum {d}"),
186        };
187        if let Some(expected) = &dim {
188            if actual_dims.ndims() != expected.ndims() {
189                let actual = actual_dims.ndims().into();
190                let expected = expected.ndims().into();
191                // All input arrays must have the same dimensionality.
192                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
193            }
194            if let Some((e, a)) = expected
195                .into_iter()
196                .zip_eq(actual_dims)
197                .find(|(e, a)| e != a)
198            {
199                let actual = a.length;
200                let expected = e.length;
201                // All input arrays must have the same dimensionality.
202                return Err(InvalidArrayError::WrongCardinality { actual, expected }.into());
203            }
204        }
205        dim = Some(actual_dims);
206    }
207    // Per PostgreSQL, if all input arrays are zero dimensional, so is the output.
208    if dim.as_ref().map_or(true, ArrayDimensions::is_empty) {
209        return Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&[], &[]))?);
210    }
211
212    let mut dims = vec![ArrayDimension {
213        lower_bound: 1,
214        length: datums.len(),
215    }];
216    if let Some(d) = datums.first() {
217        dims.extend(d.unwrap_array().dims());
218    };
219    let elements = datums
220        .iter()
221        .flat_map(|d| d.unwrap_array().elements().iter());
222    let datum =
223        temp_storage.try_make_datum(move |packer| packer.try_push_array(&dims, elements))?;
224    Ok(datum)
225}
226
227#[derive(
228    Ord,
229    PartialOrd,
230    Clone,
231    Debug,
232    Eq,
233    PartialEq,
234    Serialize,
235    Deserialize,
236    Hash,
237    MzReflect
238)]
239pub struct ArrayFill {
240    pub elem_type: SqlScalarType,
241}
242
243#[sqlfunc(
244    ArrayFill,
245    output_type_expr = "SqlScalarType::Array(Box::new(self.elem_type.clone())).nullable(false)",
246    introduces_nulls = false
247)]
248fn array_fill<'a>(
249    &self,
250    fill: Datum<'a>,
251    dims: Option<Array<'a>>,
252    lower_bounds: OptionalArg<Option<Array<'a>>>,
253    temp_storage: &'a RowArena,
254) -> Result<Datum<'a>, EvalError> {
255    const MAX_SIZE: usize = 1 << 28 - 1;
256    const NULL_ARR_ERR: &str = "dimension array or low bound array";
257    const NULL_ELEM_ERR: &str = "dimension values";
258
259    if matches!(fill, Datum::Array(_)) {
260        return Err(EvalError::Unsupported {
261            feature: "array_fill with arrays".into(),
262            discussion_no: None,
263        });
264    }
265
266    let Some(arr) = dims else {
267        return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
268    };
269
270    let dimensions = arr
271        .elements()
272        .iter()
273        .map(|d| match d {
274            Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
275            d => Ok(usize::cast_from(u32::reinterpret_cast(d.unwrap_int32()))),
276        })
277        .collect::<Result<Vec<_>, _>>()?;
278
279    let lower_bounds = match *lower_bounds {
280        Some(d) => {
281            let Some(arr) = d else {
282                return Err(EvalError::MustNotBeNull(NULL_ARR_ERR.into()));
283            };
284
285            arr.elements()
286                .iter()
287                .map(|l| match l {
288                    Datum::Null => Err(EvalError::MustNotBeNull(NULL_ELEM_ERR.into())),
289                    l => Ok(isize::cast_from(l.unwrap_int32())),
290                })
291                .collect::<Result<Vec<_>, _>>()?
292        }
293        None => {
294            vec![1isize; dimensions.len()]
295        }
296    };
297
298    if lower_bounds.len() != dimensions.len() {
299        return Err(EvalError::ArrayFillWrongArraySubscripts);
300    }
301
302    let fill_count: usize = dimensions
303        .iter()
304        .cloned()
305        .map(Some)
306        .reduce(|a, b| match (a, b) {
307            (Some(a), Some(b)) => a.checked_mul(b),
308            _ => None,
309        })
310        .flatten()
311        .ok_or(EvalError::MaxArraySizeExceeded(MAX_SIZE))?;
312
313    if matches!(
314        mz_repr::datum_size(&fill).checked_mul(fill_count),
315        None | Some(MAX_SIZE..)
316    ) {
317        return Err(EvalError::MaxArraySizeExceeded(MAX_SIZE));
318    }
319
320    let array_dimensions = if fill_count == 0 {
321        vec![ArrayDimension {
322            lower_bound: 1,
323            length: 0,
324        }]
325    } else {
326        dimensions
327            .into_iter()
328            .zip_eq(lower_bounds)
329            .map(|(length, lower_bound)| ArrayDimension {
330                lower_bound,
331                length,
332            })
333            .collect()
334    };
335
336    Ok(temp_storage.try_make_datum(|packer| {
337        packer.try_push_array(&array_dimensions, vec![fill; fill_count])
338    })?)
339}
340
341#[derive(
342    Ord,
343    PartialOrd,
344    Clone,
345    Debug,
346    Eq,
347    PartialEq,
348    Serialize,
349    Deserialize,
350    Hash,
351    MzReflect
352)]
353pub struct ArrayIndex {
354    pub offset: i64,
355}
356#[sqlfunc(ArrayIndex, sqlname = "array_index", introduces_nulls = true)]
357fn array_index<'a, T: FromDatum<'a>>(
358    &self,
359    array: Array<'a, T>,
360    indices: Variadic<i64>,
361) -> Option<T> {
362    mz_ore::soft_assert_no_log!(
363        self.offset == 0 || self.offset == 1,
364        "offset must be either 0 or 1"
365    );
366
367    let dims = array.dims();
368    if dims.len() != indices.len() {
369        // You missed the datums "layer"
370        return None;
371    }
372
373    let mut final_idx = 0;
374
375    for (d, idx) in dims.into_iter().zip_eq(indices.iter()) {
376        // Lower bound is written in terms of 1-based indexing, which offset accounts for.
377        let idx = isize::cast_from(*idx + self.offset);
378
379        let (lower, upper) = d.dimension_bounds();
380
381        // This index missed all of the data at this layer. The dimension bounds are inclusive,
382        // while range checks are exclusive, so adjust.
383        if !(lower..upper + 1).contains(&idx) {
384            return None;
385        }
386
387        // We discover how many indices our last index represents physically.
388        final_idx *= d.length;
389
390        // Because both index and lower bound are handled in 1-based indexing, taking their
391        // difference moves us back into 0-based indexing. Similarly, if the lower bound is
392        // negative, subtracting a negative value >= to itself ensures its non-negativity.
393        final_idx += usize::try_from(idx - d.lower_bound)
394            .expect("previous bounds check ensures physical index is at least 0");
395    }
396
397    array.elements().typed_iter().nth(final_idx)
398}
399
400#[sqlfunc]
401fn array_position<'a>(
402    array: Array<'a>,
403    search: Datum<'a>,
404    initial_pos: OptionalArg<Option<i32>>,
405) -> Result<Option<i32>, EvalError> {
406    if array.dims().len() > 1 {
407        return Err(EvalError::MultiDimensionalArraySearch);
408    }
409
410    if search == Datum::Null {
411        return Ok(None);
412    }
413
414    let skip = match initial_pos.0 {
415        None => 0,
416        Some(None) => return Err(EvalError::MustNotBeNull("initial position".into())),
417        Some(Some(o)) => usize::try_from(o).unwrap_or(0).saturating_sub(1),
418    };
419
420    let Some(r) = array.elements().iter().skip(skip).position(|d| d == search) else {
421        return Ok(None);
422    };
423
424    // Adjust count for the amount we skipped, plus 1 for adjusting to PG indexing scheme.
425    let p = i32::try_from(r + skip + 1).expect("fewer than i32::MAX elements in array");
426    Ok(Some(p))
427}
428
429#[derive(
430    Ord,
431    PartialOrd,
432    Clone,
433    Debug,
434    Eq,
435    PartialEq,
436    Serialize,
437    Deserialize,
438    Hash,
439    MzReflect
440)]
441pub struct ArrayToString {
442    pub elem_type: SqlScalarType,
443}
444
445#[sqlfunc]
446fn array_to_string<'a>(
447    &self,
448    array: Array<'a>,
449    delimiter: &str,
450    null_str_arg: OptionalArg<Option<&str>>,
451) -> Result<String, EvalError> {
452    // `flatten` treats absent arguments (`None`) the same as explicit NULL
453    // (`Some(None)`), both becoming `None`.
454    let null_str = null_str_arg.flatten();
455    let mut out = String::new();
456    for elem in array.elements().iter() {
457        if elem.is_null() {
458            if let Some(null_str) = null_str {
459                out.push_str(null_str);
460                out.push_str(delimiter);
461            }
462        } else {
463            stringify_datum(&mut out, elem, &self.elem_type)?;
464            out.push_str(delimiter);
465        }
466    }
467    if out.len() > 0 {
468        // Lop off last delimiter only if string is not empty
469        out.truncate(out.len() - delimiter.len());
470    }
471    Ok(out)
472}
473
474#[derive(
475    Ord,
476    PartialOrd,
477    Clone,
478    Debug,
479    Eq,
480    PartialEq,
481    Serialize,
482    Deserialize,
483    Hash,
484    MzReflect
485)]
486pub struct Coalesce;
487
488impl fmt::Display for Coalesce {
489    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
490        f.write_str("coalesce")
491    }
492}
493
494impl LazyVariadicFunc for Coalesce {
495    fn eval<'a>(
496        &'a self,
497        datums: &[Datum<'a>],
498        temp_storage: &'a RowArena,
499        exprs: &'a [impl Eval],
500    ) -> Result<Datum<'a>, EvalError> {
501        for e in exprs {
502            let d = e.eval(datums, temp_storage)?;
503            if !d.is_null() {
504                return Ok(d);
505            }
506        }
507        Ok(Datum::Null)
508    }
509
510    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
511        // Note that the parser doesn't allow empty argument lists for variadic functions
512        // that use the standard function call syntax (ArrayCreate and co. are different
513        // because of the special syntax for calling them).
514        let nullable = input_types.iter().all(|typ| typ.nullable);
515        SqlColumnType::union_many(input_types).nullable(nullable)
516    }
517
518    fn propagates_nulls(&self) -> bool {
519        false
520    }
521
522    fn introduces_nulls(&self) -> bool {
523        false
524    }
525
526    fn could_error(&self) -> bool {
527        false
528    }
529
530    fn is_monotone(&self) -> bool {
531        true
532    }
533
534    fn is_associative(&self) -> bool {
535        true
536    }
537}
538
539#[derive(
540    Ord,
541    PartialOrd,
542    Clone,
543    Debug,
544    Eq,
545    PartialEq,
546    Serialize,
547    Deserialize,
548    Hash,
549    MzReflect
550)]
551pub struct RangeCreate {
552    pub elem_type: SqlScalarType,
553}
554
555impl fmt::Display for RangeCreate {
556    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
557        f.write_str(match &self.elem_type {
558            SqlScalarType::Int32 => "int4range",
559            SqlScalarType::Int64 => "int8range",
560            SqlScalarType::Date => "daterange",
561            SqlScalarType::Numeric { .. } => "numrange",
562            SqlScalarType::Timestamp { .. } => "tsrange",
563            SqlScalarType::TimestampTz { .. } => "tstzrange",
564            _ => unreachable!(),
565        })
566    }
567}
568
569impl EagerVariadicFunc for RangeCreate {
570    type Input<'a> = (Datum<'a>, Datum<'a>, Datum<'a>);
571    type Output<'a> = Result<Datum<'a>, EvalError>;
572
573    fn call<'a>(
574        &self,
575        (lower, upper, flags_datum): Self::Input<'a>,
576        temp_storage: &'a RowArena,
577    ) -> Self::Output<'a> {
578        let flags = match flags_datum {
579            Datum::Null => {
580                return Err(EvalError::InvalidRange(
581                    InvalidRangeError::NullRangeBoundFlags,
582                ));
583            }
584            o => o.unwrap_str(),
585        };
586
587        let (lower_inclusive, upper_inclusive) = parse_range_bound_flags(flags)?;
588
589        let mut range = Range::new(Some((
590            RangeBound::new(lower, lower_inclusive),
591            RangeBound::new(upper, upper_inclusive),
592        )));
593
594        range.canonicalize()?;
595
596        Ok(temp_storage.make_datum(|row| {
597            row.push_range(range).expect("errors already handled");
598        }))
599    }
600
601    fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
602        SqlScalarType::Range {
603            element_type: Box::new(self.elem_type.clone()),
604        }
605        .nullable(false)
606    }
607
608    fn introduces_nulls(&self) -> bool {
609        false
610    }
611}
612
613#[sqlfunc(sqlname = "datediff")]
614fn date_diff_date(unit_str: &str, a: Date, b: Date) -> Result<i64, EvalError> {
615    let unit = unit_str
616        .parse()
617        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
618
619    // Convert the Date into a timestamp so we can calculate age.
620    let a_ts = CheckedTimestamp::try_from(NaiveDate::from(a).and_hms_opt(0, 0, 0).unwrap())?;
621    let b_ts = CheckedTimestamp::try_from(NaiveDate::from(b).and_hms_opt(0, 0, 0).unwrap())?;
622    let diff = b_ts.diff_as(&a_ts, unit)?;
623    Ok(diff)
624}
625
626#[sqlfunc(sqlname = "datediff")]
627fn date_diff_time(unit_str: &str, a: NaiveTime, b: NaiveTime) -> Result<i64, EvalError> {
628    let unit = unit_str
629        .parse()
630        .map_err(|_| EvalError::InvalidDatePart(unit_str.into()))?;
631
632    // Convert the Time into a timestamp so we can calculate age.
633    let a_ts =
634        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(a))?;
635    let b_ts =
636        CheckedTimestamp::try_from(NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_time(b))?;
637    let diff = b_ts.diff_as(&a_ts, unit)?;
638    Ok(diff)
639}
640
641#[sqlfunc(sqlname = "datediff")]
642fn date_diff_timestamp(
643    unit: &str,
644    a: CheckedTimestamp<NaiveDateTime>,
645    b: CheckedTimestamp<NaiveDateTime>,
646) -> Result<i64, EvalError> {
647    let unit = unit
648        .parse()
649        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
650
651    let diff = b.diff_as(&a, unit)?;
652    Ok(diff)
653}
654
655#[sqlfunc(sqlname = "datediff")]
656fn date_diff_timestamp_tz(
657    unit: &str,
658    a: CheckedTimestamp<DateTime<Utc>>,
659    b: CheckedTimestamp<DateTime<Utc>>,
660) -> Result<i64, EvalError> {
661    let unit = unit
662        .parse()
663        .map_err(|_| EvalError::InvalidDatePart(unit.into()))?;
664
665    let diff = b.diff_as(&a, unit)?;
666    Ok(diff)
667}
668
669#[derive(
670    Ord,
671    PartialOrd,
672    Clone,
673    Debug,
674    Eq,
675    PartialEq,
676    Serialize,
677    Deserialize,
678    Hash,
679    MzReflect
680)]
681pub struct ErrorIfNull;
682
683impl fmt::Display for ErrorIfNull {
684    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
685        f.write_str("error_if_null")
686    }
687}
688
689impl LazyVariadicFunc for ErrorIfNull {
690    fn eval<'a>(
691        &'a self,
692        datums: &[Datum<'a>],
693        temp_storage: &'a RowArena,
694        exprs: &'a [impl Eval],
695    ) -> Result<Datum<'a>, EvalError> {
696        let first = exprs[0].eval(datums, temp_storage)?;
697        match first {
698            Datum::Null => {
699                let err_msg = match exprs[1].eval(datums, temp_storage)? {
700                    Datum::Null => {
701                        return Err(EvalError::Internal(
702                            "unexpected NULL in error side of error_if_null".into(),
703                        ));
704                    }
705                    o => o.unwrap_str(),
706                };
707                Err(EvalError::IfNullError(err_msg.into()))
708            }
709            _ => Ok(first),
710        }
711    }
712
713    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
714        input_types[0].scalar_type.clone().nullable(false)
715    }
716
717    fn propagates_nulls(&self) -> bool {
718        false
719    }
720
721    fn introduces_nulls(&self) -> bool {
722        false
723    }
724}
725
726#[derive(
727    Ord,
728    PartialOrd,
729    Clone,
730    Debug,
731    Eq,
732    PartialEq,
733    Serialize,
734    Deserialize,
735    Hash,
736    MzReflect
737)]
738pub struct Greatest;
739
740impl fmt::Display for Greatest {
741    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
742        f.write_str("greatest")
743    }
744}
745
746impl LazyVariadicFunc for Greatest {
747    fn eval<'a>(
748        &'a self,
749        datums: &[Datum<'a>],
750        temp_storage: &'a RowArena,
751        exprs: &'a [impl Eval],
752    ) -> Result<Datum<'a>, EvalError> {
753        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
754        Ok(datums
755            .filter(|d| Ok(!d.is_null()))
756            .max()?
757            .unwrap_or(Datum::Null))
758    }
759
760    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
761        SqlColumnType::union_many(input_types)
762    }
763
764    fn propagates_nulls(&self) -> bool {
765        false
766    }
767
768    fn introduces_nulls(&self) -> bool {
769        false
770    }
771
772    fn could_error(&self) -> bool {
773        false
774    }
775
776    fn is_monotone(&self) -> bool {
777        true
778    }
779
780    fn is_associative(&self) -> bool {
781        true
782    }
783}
784
785#[sqlfunc(sqlname = "hmac")]
786fn hmac_string(to_digest: &str, key: &str, typ: &str) -> Result<Vec<u8>, EvalError> {
787    let to_digest = to_digest.as_bytes();
788    let key = key.as_bytes();
789    hmac_inner(to_digest, key, typ)
790}
791
792#[sqlfunc(sqlname = "hmac")]
793fn hmac_bytes(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
794    hmac_inner(to_digest, key, typ)
795}
796
797pub fn hmac_inner(to_digest: &[u8], key: &[u8], typ: &str) -> Result<Vec<u8>, EvalError> {
798    match typ {
799        "md5" => {
800            let mut mac = Hmac::<Md5>::new_from_slice(key).expect("HMAC accepts any key size");
801            mac.update(to_digest);
802            Ok(mac.finalize().into_bytes().to_vec())
803        }
804        "sha1" => {
805            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key);
806            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
807        }
808        "sha224" => {
809            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA224, key);
810            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
811        }
812        "sha256" => {
813            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA256, key);
814            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
815        }
816        "sha384" => {
817            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA384, key);
818            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
819        }
820        "sha512" => {
821            let k = aws_hmac::Key::new(aws_hmac::HMAC_SHA512, key);
822            Ok(aws_hmac::sign(&k, to_digest).as_ref().to_vec())
823        }
824        other => Err(EvalError::InvalidHashAlgorithm(other.into())),
825    }
826}
827
828#[sqlfunc]
829fn jsonb_build_array<'a>(datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> JsonbRef<'a> {
830    let datum = temp_storage.make_datum(|packer| {
831        packer.push_list(datums.into_iter().map(|d| match d {
832            Datum::Null => Datum::JsonNull,
833            d => d,
834        }))
835    });
836    JsonbRef::from_datum(datum)
837}
838
839#[sqlfunc]
840fn jsonb_build_object<'a>(
841    mut kvs: Variadic<(Datum<'a>, Datum<'a>)>,
842    temp_storage: &'a RowArena,
843) -> Result<JsonbRef<'a>, EvalError> {
844    kvs.0.sort_by(|kv1, kv2| kv1.0.cmp(&kv2.0));
845    kvs.0.dedup_by(|kv1, kv2| kv1.0 == kv2.0);
846    let datum = temp_storage.try_make_datum(|packer| {
847        packer.push_dict_with(|packer| {
848            for (k, v) in kvs {
849                if k.is_null() {
850                    return Err(EvalError::KeyCannotBeNull);
851                }
852                let v = match v {
853                    Datum::Null => Datum::JsonNull,
854                    d => d,
855                };
856                packer.push(k);
857                packer.push(v);
858            }
859            Ok(())
860        })
861    })?;
862    Ok(JsonbRef::from_datum(datum))
863}
864
865#[derive(
866    Ord,
867    PartialOrd,
868    Clone,
869    Debug,
870    Eq,
871    PartialEq,
872    Serialize,
873    Deserialize,
874    Hash,
875    MzReflect
876)]
877pub struct Least;
878
879impl fmt::Display for Least {
880    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
881        f.write_str("least")
882    }
883}
884
885impl LazyVariadicFunc for Least {
886    fn eval<'a>(
887        &'a self,
888        datums: &[Datum<'a>],
889        temp_storage: &'a RowArena,
890        exprs: &'a [impl Eval],
891    ) -> Result<Datum<'a>, EvalError> {
892        let datums = fallible_iterator::convert(exprs.iter().map(|e| e.eval(datums, temp_storage)));
893        Ok(datums
894            .filter(|d| Ok(!d.is_null()))
895            .min()?
896            .unwrap_or(Datum::Null))
897    }
898
899    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
900        SqlColumnType::union_many(input_types)
901    }
902
903    fn propagates_nulls(&self) -> bool {
904        false
905    }
906
907    fn introduces_nulls(&self) -> bool {
908        false
909    }
910
911    fn could_error(&self) -> bool {
912        false
913    }
914
915    fn is_monotone(&self) -> bool {
916        true
917    }
918
919    fn is_associative(&self) -> bool {
920        true
921    }
922}
923
924#[derive(
925    Ord,
926    PartialOrd,
927    Clone,
928    Debug,
929    Eq,
930    PartialEq,
931    Serialize,
932    Deserialize,
933    Hash,
934    MzReflect
935)]
936pub struct ListCreate {
937    pub elem_type: SqlScalarType,
938}
939
940#[sqlfunc(
941    output_type_expr = "SqlScalarType::List { element_type: Box::new(self.elem_type.clone()), custom_id: None }.nullable(false)",
942    introduces_nulls = false
943)]
944fn list_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
945    temp_storage.make_datum(|packer| packer.push_list(datums))
946}
947
948#[derive(
949    Ord,
950    PartialOrd,
951    Clone,
952    Debug,
953    Eq,
954    PartialEq,
955    Serialize,
956    Deserialize,
957    Hash,
958    MzReflect
959)]
960pub struct RecordCreate {
961    pub field_names: Vec<ColumnName>,
962}
963
964#[sqlfunc(
965    output_type_expr = "SqlScalarType::Record { fields: self.field_names.clone().into_iter().zip_eq(input_types.iter().cloned()).collect(), custom_id: None }.nullable(false)",
966    introduces_nulls = false
967)]
968fn record_create<'a>(&self, datums: Variadic<Datum<'a>>, temp_storage: &'a RowArena) -> Datum<'a> {
969    temp_storage.make_datum(|packer| packer.push_list(datums.iter().copied()))
970}
971
972#[sqlfunc(
973    output_type_expr = "input_types[0].scalar_type.unwrap_list_nth_layer_type(input_types.len() - 1).clone().nullable(true)",
974    introduces_nulls = true
975)]
976// TODO(benesch): remove potentially dangerous usage of `as`.
977#[allow(clippy::as_conversions)]
978fn list_index<'a>(buf: DatumList<'a>, indices: Variadic<i64>) -> Datum<'a> {
979    let mut buf = Datum::List(buf);
980    for i in indices {
981        if buf.is_null() {
982            break;
983        }
984        if i < 1 {
985            return Datum::Null;
986        }
987
988        buf = match buf.unwrap_list().iter().nth(i as usize - 1) {
989            Some(datum) => datum,
990            None => return Datum::Null,
991        }
992    }
993    buf
994}
995
996#[sqlfunc(sqlname = "makeaclitem")]
997fn make_acl_item(
998    grantee_oid: u32,
999    grantor_oid: u32,
1000    privileges: &str,
1001    is_grantable: bool,
1002) -> Result<AclItem, EvalError> {
1003    let grantee = Oid(grantee_oid);
1004    let grantor = Oid(grantor_oid);
1005    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1006        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1007    if is_grantable {
1008        return Err(EvalError::Unsupported {
1009            feature: "GRANT OPTION".into(),
1010            discussion_no: None,
1011        });
1012    }
1013
1014    Ok(AclItem {
1015        grantee,
1016        grantor,
1017        acl_mode,
1018    })
1019}
1020
1021#[sqlfunc(sqlname = "make_mz_aclitem")]
1022fn make_mz_acl_item(
1023    grantee_str: &str,
1024    grantor_str: &str,
1025    privileges: &str,
1026) -> Result<MzAclItem, EvalError> {
1027    let grantee: RoleId = grantee_str
1028        .parse()
1029        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1030    let grantor: RoleId = grantor_str
1031        .parse()
1032        .map_err(|e: anyhow::Error| EvalError::InvalidRoleId(e.to_string().into()))?;
1033    if grantor == RoleId::Public {
1034        return Err(EvalError::InvalidRoleId(
1035            "mz_aclitem grantor cannot be PUBLIC role".into(),
1036        ));
1037    }
1038    let acl_mode = AclMode::parse_multiple_privileges(privileges)
1039        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
1040
1041    Ok(MzAclItem {
1042        grantee,
1043        grantor,
1044        acl_mode,
1045    })
1046}
1047
1048#[sqlfunc(sqlname = "makets")]
1049// TODO(benesch): remove potentially dangerous usage of `as`.
1050#[allow(clippy::as_conversions)]
1051fn make_timestamp(
1052    year: i64,
1053    month: i64,
1054    day: i64,
1055    hour: i64,
1056    minute: i64,
1057    second_float: f64,
1058) -> Result<Option<CheckedTimestamp<NaiveDateTime>>, EvalError> {
1059    let year: i32 = match year.try_into() {
1060        Ok(year) => year,
1061        Err(_) => return Ok(None),
1062    };
1063    let month: u32 = match month.try_into() {
1064        Ok(month) => month,
1065        Err(_) => return Ok(None),
1066    };
1067    let day: u32 = match day.try_into() {
1068        Ok(day) => day,
1069        Err(_) => return Ok(None),
1070    };
1071    let hour: u32 = match hour.try_into() {
1072        Ok(hour) => hour,
1073        Err(_) => return Ok(None),
1074    };
1075    let minute: u32 = match minute.try_into() {
1076        Ok(minute) => minute,
1077        Err(_) => return Ok(None),
1078    };
1079    let second = second_float as u32;
1080    let micros = ((second_float - second as f64) * 1_000_000.0) as u32;
1081    let date = match NaiveDate::from_ymd_opt(year, month, day) {
1082        Some(date) => date,
1083        None => return Ok(None),
1084    };
1085    let timestamp = match date.and_hms_micro_opt(hour, minute, second, micros) {
1086        Some(timestamp) => timestamp,
1087        None => return Ok(None),
1088    };
1089    Ok(Some(timestamp.try_into()?))
1090}
1091
/// Variadic function that builds a map value from `(key, value)` argument pairs.
#[derive(
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct MapBuild {
    /// Scalar type of the map's values; used as the `value_type` of the
    /// `SqlScalarType::Map` reported as this function's output type.
    pub value_type: SqlScalarType,
}
1107
1108#[sqlfunc(
1109    output_type_expr = "SqlScalarType::Map { value_type: Box::new(self.value_type.clone()), custom_id: None }.nullable(false)",
1110    introduces_nulls = false
1111)]
1112fn map_build<'a>(
1113    &self,
1114    datums: Variadic<(Option<&str>, Datum<'a>)>,
1115    temp_storage: &'a RowArena,
1116) -> Datum<'a> {
1117    // Collect into a `BTreeMap` to provide the same semantics as it.
1118    let map: std::collections::BTreeMap<&str, _> = datums
1119        .into_iter()
1120        .filter_map(|(k, v)| k.map(|k| (k, v)))
1121        .collect();
1122
1123    temp_storage.make_datum(|packer| packer.push_dict(map))
1124}
1125
/// Logical `OR` over an arbitrary number of operands, evaluated lazily so that
/// a `true` operand can short-circuit the remaining ones.
#[derive(
    Ord,
    PartialOrd,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct Or;
1139
1140impl fmt::Display for Or {
1141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1142        f.write_str("OR")
1143    }
1144}
1145
1146impl LazyVariadicFunc for Or {
1147    fn eval<'a>(
1148        &'a self,
1149        datums: &[Datum<'a>],
1150        temp_storage: &'a RowArena,
1151        exprs: &'a [impl Eval],
1152    ) -> Result<Datum<'a>, EvalError> {
1153        // If any is true, then return true. Else, if any is null, then return null. Else, return false.
1154        let mut null = false;
1155        let mut err = None;
1156        for expr in exprs {
1157            match expr.eval(datums, temp_storage) {
1158                Ok(Datum::False) => {}
1159                Ok(Datum::True) => return Ok(Datum::True), // short-circuit
1160                // No return in these two cases, because we might still see a true
1161                Ok(Datum::Null) => null = true,
1162                Err(this_err) => err = std::cmp::max(err.take(), Some(this_err)),
1163                _ => unreachable!(),
1164            }
1165        }
1166        match (err, null) {
1167            (Some(err), _) => Err(err),
1168            (None, true) => Ok(Datum::Null),
1169            (None, false) => Ok(Datum::False),
1170        }
1171    }
1172
1173    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1174        let in_nullable = input_types.iter().any(|t| t.nullable);
1175        SqlScalarType::Bool.nullable(in_nullable)
1176    }
1177
1178    fn propagates_nulls(&self) -> bool {
1179        false
1180    }
1181
1182    fn introduces_nulls(&self) -> bool {
1183        false
1184    }
1185
1186    fn could_error(&self) -> bool {
1187        false
1188    }
1189
1190    fn is_monotone(&self) -> bool {
1191        true
1192    }
1193
1194    fn is_associative(&self) -> bool {
1195        true
1196    }
1197
1198    fn is_infix_op(&self) -> bool {
1199        true
1200    }
1201}
1202
1203#[sqlfunc(sqlname = "lpad")]
1204fn pad_leading(string: &str, raw_len: i32, pad: OptionalArg<&str>) -> Result<String, EvalError> {
1205    let len = match usize::try_from(raw_len) {
1206        Ok(len) => len,
1207        Err(_) => {
1208            return Err(EvalError::InvalidParameterValue(
1209                "length must be nonnegative".into(),
1210            ));
1211        }
1212    };
1213    if len > MAX_STRING_FUNC_RESULT_BYTES {
1214        return Err(EvalError::LengthTooLarge);
1215    }
1216
1217    let pad_string = pad.unwrap_or(" ");
1218
1219    let (end_char, end_char_byte_offset) = string
1220        .chars()
1221        .take(len)
1222        .fold((0, 0), |acc, char| (acc.0 + 1, acc.1 + char.len_utf8()));
1223
1224    let mut buf = String::with_capacity(len);
1225    if len == end_char {
1226        buf.push_str(&string[0..end_char_byte_offset]);
1227    } else {
1228        buf.extend(pad_string.chars().cycle().take(len - end_char));
1229        buf.push_str(string);
1230    }
1231
1232    Ok(buf)
1233}
1234
1235#[sqlfunc(
1236    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(true)",
1237    introduces_nulls = true
1238)]
1239fn regexp_match<'a>(
1240    haystack: &'a str,
1241    needle: &str,
1242    flags: OptionalArg<&str>,
1243    temp_storage: &'a RowArena,
1244) -> Result<Datum<'a>, EvalError> {
1245    let flags = flags.unwrap_or("");
1246    let needle = build_regex(needle, flags)?;
1247    regexp_match_static(Datum::String(haystack), temp_storage, &needle)
1248}
1249
1250#[sqlfunc(
1251    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1252    introduces_nulls = false
1253)]
1254fn regexp_split_to_array<'a>(
1255    text: &str,
1256    regexp_str: &str,
1257    flags: OptionalArg<&str>,
1258    temp_storage: &'a RowArena,
1259) -> Result<Datum<'a>, EvalError> {
1260    let flags = flags.unwrap_or("");
1261    let regexp = build_regex(regexp_str, flags)?;
1262    regexp_split_to_array_re(text, &regexp, temp_storage)
1263}
1264
1265#[sqlfunc]
1266fn regexp_replace<'a>(
1267    source: &'a str,
1268    pattern: &str,
1269    replacement: &str,
1270    flags_opt: OptionalArg<&str>,
1271) -> Result<Cow<'a, str>, EvalError> {
1272    let flags = flags_opt.0.unwrap_or("");
1273    let (limit, flags) = regexp_replace_parse_flags(flags);
1274    let regexp = build_regex(pattern, &flags)?;
1275    Ok(regexp.replacen(source, limit, replacement))
1276}
1277
1278#[sqlfunc]
1279fn replace(text: &str, from: &str, to: &str) -> Result<String, EvalError> {
1280    // As a compromise to avoid always nearly duplicating the work of replace by doing size estimation,
1281    // we first check if it's possible for the fully replaced string to exceed the limit by assuming that
1282    // every possible substring is replaced.
1283    //
1284    // If that estimate exceeds the limit, we then do a more precise (and expensive) estimate by counting
1285    // the actual number of replacements that would occur, and using that to calculate the final size.
1286    let possible_size = text.len() * to.len();
1287    if possible_size > MAX_STRING_FUNC_RESULT_BYTES {
1288        let replacement_count = text.matches(from).count();
1289        let estimated_size = text.len() + replacement_count * (to.len().saturating_sub(from.len()));
1290        if estimated_size > MAX_STRING_FUNC_RESULT_BYTES {
1291            return Err(EvalError::LengthTooLarge);
1292        }
1293    }
1294
1295    Ok(text.replace(from, to))
1296}
1297
1298#[sqlfunc(
1299    output_type_expr = "SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)",
1300    introduces_nulls = false,
1301    propagates_nulls = false
1302)]
1303fn string_to_array<'a>(
1304    string: &'a str,
1305    delimiter: Option<&'a str>,
1306    null_string: OptionalArg<Option<&'a str>>,
1307    temp_storage: &'a RowArena,
1308) -> Result<Datum<'a>, EvalError> {
1309    if string.is_empty() {
1310        let mut row = Row::default();
1311        let mut packer = row.packer();
1312        packer.try_push_array(&[], std::iter::empty::<Datum>())?;
1313
1314        return Ok(temp_storage.push_unary_row(row));
1315    }
1316
1317    let Some(delimiter) = delimiter else {
1318        let split_all_chars_delimiter = "";
1319        return string_to_array_impl(
1320            string,
1321            split_all_chars_delimiter,
1322            null_string.flatten(),
1323            temp_storage,
1324        );
1325    };
1326
1327    if delimiter.is_empty() {
1328        let mut row = Row::default();
1329        let mut packer = row.packer();
1330        let dims = &[ArrayDimension {
1331            lower_bound: 1,
1332            length: 1,
1333        }];
1334        match null_string.flatten() {
1335            Some(null_string) if null_string == string => {
1336                packer.try_push_array(dims, std::iter::once(Datum::Null))?;
1337            }
1338            _ => {
1339                packer.try_push_array(dims, vec![string].into_iter().map(Datum::String))?;
1340            }
1341        }
1342        Ok(temp_storage.push_unary_row(row))
1343    } else {
1344        string_to_array_impl(string, delimiter, null_string.flatten(), temp_storage)
1345    }
1346}
1347
1348fn string_to_array_impl<'a>(
1349    string: &str,
1350    delimiter: &str,
1351    null_string: Option<&'a str>,
1352    temp_storage: &'a RowArena,
1353) -> Result<Datum<'a>, EvalError> {
1354    let mut row = Row::default();
1355    let mut packer = row.packer();
1356
1357    let result = string.split(delimiter);
1358    let found: Vec<&str> = if delimiter.is_empty() {
1359        result.filter(|s| !s.is_empty()).collect()
1360    } else {
1361        result.collect()
1362    };
1363    let array_dimensions = [ArrayDimension {
1364        lower_bound: 1,
1365        length: found.len(),
1366    }];
1367
1368    if let Some(null_string) = null_string {
1369        let found_datums = found.into_iter().map(|chunk| {
1370            if chunk.eq(null_string) {
1371                Datum::Null
1372            } else {
1373                Datum::String(chunk)
1374            }
1375        });
1376
1377        packer.try_push_array(&array_dimensions, found_datums)?;
1378    } else {
1379        packer.try_push_array(&array_dimensions, found.into_iter().map(Datum::String))?;
1380    }
1381
1382    Ok(temp_storage.push_unary_row(row))
1383}
1384
1385#[sqlfunc]
1386fn substr<'a>(s: &'a str, start: i32, length: OptionalArg<i32>) -> Result<&'a str, EvalError> {
1387    let raw_start_idx = i64::from(start) - 1;
1388    let start_idx = match usize::try_from(cmp::max(raw_start_idx, 0)) {
1389        Ok(i) => i,
1390        Err(_) => {
1391            return Err(EvalError::InvalidParameterValue(
1392                format!(
1393                    "substring starting index ({}) exceeds min/max position",
1394                    raw_start_idx
1395                )
1396                .into(),
1397            ));
1398        }
1399    };
1400
1401    let mut char_indices = s.char_indices();
1402    let get_str_index = |(index, _char)| index;
1403
1404    let str_len = s.len();
1405    let start_char_idx = char_indices.nth(start_idx).map_or(str_len, get_str_index);
1406
1407    if let OptionalArg(Some(len)) = length {
1408        let end_idx = match i64::from(len) {
1409            e if e < 0 => {
1410                return Err(EvalError::InvalidParameterValue(
1411                    "negative substring length not allowed".into(),
1412                ));
1413            }
1414            e if e == 0 || e + raw_start_idx < 1 => return Ok(""),
1415            e => {
1416                let e = cmp::min(raw_start_idx + e - 1, e - 1);
1417                match usize::try_from(e) {
1418                    Ok(i) => i,
1419                    Err(_) => {
1420                        return Err(EvalError::InvalidParameterValue(
1421                            format!("substring length ({}) exceeds max position", e).into(),
1422                        ));
1423                    }
1424                }
1425            }
1426        };
1427
1428        let end_char_idx = char_indices.nth(end_idx).map_or(str_len, get_str_index);
1429
1430        Ok(&s[start_char_idx..end_char_idx])
1431    } else {
1432        Ok(&s[start_char_idx..])
1433    }
1434}
1435
1436#[sqlfunc(sqlname = "split_string")]
1437fn split_part<'a>(string: &'a str, delimiter: &str, field: i32) -> Result<&'a str, EvalError> {
1438    let index = match usize::try_from(i64::from(field) - 1) {
1439        Ok(index) => index,
1440        Err(_) => {
1441            return Err(EvalError::InvalidParameterValue(
1442                "field position must be greater than zero".into(),
1443            ));
1444        }
1445    };
1446
1447    // If the provided delimiter is the empty string,
1448    // PostgreSQL does not break the string into individual
1449    // characters. Instead, it generates the following parts: [string].
1450    if delimiter.is_empty() {
1451        if index == 0 {
1452            return Ok(string);
1453        } else {
1454            return Ok("");
1455        }
1456    }
1457
1458    // If provided index is greater than the number of split parts,
1459    // return an empty string.
1460    Ok(string.split(delimiter).nth(index).unwrap_or(""))
1461}
1462
1463#[sqlfunc(is_associative = true)]
1464fn concat(strs: Variadic<Option<&str>>) -> Result<String, EvalError> {
1465    let mut total_size = 0;
1466    for s in &strs {
1467        if let Some(s) = s {
1468            total_size += s.len();
1469            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1470                return Err(EvalError::LengthTooLarge);
1471            }
1472        }
1473    }
1474    let mut buf = String::with_capacity(total_size);
1475    for s in strs {
1476        if let Some(s) = s {
1477            buf.push_str(s);
1478        }
1479    }
1480    Ok(buf)
1481}
1482
1483#[sqlfunc]
1484fn concat_ws(ws: &str, rest: Variadic<Option<&str>>) -> Result<String, EvalError> {
1485    let mut total_size = 0;
1486    for s in &rest {
1487        if let Some(s) = s {
1488            total_size += s.len();
1489            total_size += ws.len();
1490            if total_size > MAX_STRING_FUNC_RESULT_BYTES {
1491                return Err(EvalError::LengthTooLarge);
1492            }
1493        }
1494    }
1495
1496    let buf = Itertools::join(&mut rest.into_iter().filter_map(|s| s), ws);
1497
1498    Ok(buf)
1499}
1500
1501#[sqlfunc]
1502fn translate(string: &str, from_str: &str, to_str: &str) -> String {
1503    let from = from_str.chars().collect::<Vec<_>>();
1504    let to = to_str.chars().collect::<Vec<_>>();
1505
1506    string
1507        .chars()
1508        .filter_map(|c| match from.iter().position(|f| f == &c) {
1509            Some(idx) => to.get(idx).copied(),
1510            None => Some(c),
1511        })
1512        .collect()
1513}
1514
1515#[sqlfunc(
1516    output_type_expr = "input_types[0].scalar_type.clone().nullable(false)",
1517    introduces_nulls = false
1518)]
1519// TODO(benesch): remove potentially dangerous usage of `as`.
1520#[allow(clippy::as_conversions)]
1521fn list_slice_linear<'a>(
1522    list: DatumList<'a>,
1523    first: (i64, i64),
1524    remainder: Variadic<(i64, i64)>,
1525    temp_storage: &'a RowArena,
1526) -> Datum<'a> {
1527    let mut start_idx = 0;
1528    let mut total_length = usize::MAX;
1529
1530    for (start, end) in std::iter::once(first).chain(remainder) {
1531        let start = std::cmp::max(start, 1);
1532
1533        // Result should be empty list.
1534        if start > end {
1535            start_idx = 0;
1536            total_length = 0;
1537            break;
1538        }
1539
1540        let start_inner = start as usize - 1;
1541        // Start index only moves to geq positions.
1542        start_idx += start_inner;
1543
1544        // Length index only moves to leq positions
1545        let length_inner = (end - start) as usize + 1;
1546        total_length = std::cmp::min(length_inner, total_length - start_inner);
1547    }
1548
1549    let iter = list.iter().skip(start_idx).take(total_length);
1550
1551    temp_storage.make_datum(|row| {
1552        row.push_list_with(|row| {
1553            // if iter is empty, will get the appropriate empty list.
1554            for d in iter {
1555                row.push(d);
1556            }
1557        });
1558    })
1559}
1560
1561#[sqlfunc(sqlname = "timestamp_bin")]
1562fn date_bin_timestamp(
1563    stride: Interval,
1564    source: CheckedTimestamp<NaiveDateTime>,
1565    origin: CheckedTimestamp<NaiveDateTime>,
1566) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
1567    date_bin(stride, source, origin)
1568}
1569
1570#[sqlfunc(sqlname = "timestamptz_bin")]
1571fn date_bin_timestamp_tz(
1572    stride: Interval,
1573    source: CheckedTimestamp<DateTime<Utc>>,
1574    origin: CheckedTimestamp<DateTime<Utc>>,
1575) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
1576    date_bin(stride, source, origin)
1577}
1578
1579#[sqlfunc(sqlname = "timezonet")]
1580fn timezone_time_variadic(
1581    tz_str: &str,
1582    time: NaiveTime,
1583    wall_time: CheckedTimestamp<DateTime<Utc>>,
1584) -> Result<NaiveTime, EvalError> {
1585    parse_timezone(tz_str, TimezoneSpec::Posix)
1586        .map(|tz| timezone_time(tz, time, &wall_time.naive_utc()))
1587}
/// A variadic function that receives its arguments as unevaluated
/// expressions, so the implementation decides which of them to evaluate and
/// in what order.
pub(crate) trait LazyVariadicFunc: fmt::Display {
    /// Evaluates the function.
    ///
    /// `exprs` are the argument expressions, `datums` is the input they are
    /// evaluated against, and `temp_storage` is the arena handed to each
    /// argument's `eval` call.
    fn eval<'a>(
        &'a self,
        datums: &[Datum<'a>],
        temp_storage: &'a RowArena,
        exprs: &'a [impl Eval],
    ) -> Result<Datum<'a>, EvalError>;

    /// The output SqlColumnType of this function.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input.
    fn propagates_nulls(&self) -> bool;

    /// Whether this function will produce NULL on non-NULL input.
    fn introduces_nulls(&self) -> bool;

    /// Whether this function might error on non-error input.
    ///
    /// Defaults to `true`, the conservative answer.
    fn could_error(&self) -> bool {
        true
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1625
/// A variadic function whose arguments are all evaluated and converted into a
/// typed input before the function body runs.
///
/// Every implementor also gets a `LazyVariadicFunc` implementation through a
/// blanket impl.
pub(crate) trait EagerVariadicFunc: fmt::Display {
    /// The typed form of the already-evaluated arguments.
    type Input<'a>: InputDatumType<'a, EvalError>;
    /// The typed result, convertible into a `Datum` or an `EvalError`.
    type Output<'a>: OutputDatumType<'a, EvalError>;

    /// Applies the function to its typed input.
    fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;

    /// The output SqlColumnType of this function.
    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;

    /// Whether this function will produce NULL on NULL input.
    ///
    /// By default this is true exactly when the input type does not report
    /// itself as nullable.
    fn propagates_nulls(&self) -> bool {
        !Self::Input::nullable()
    }

    /// Whether this function will produce NULL on non-NULL input.
    ///
    /// By default this follows the nullability reported by the output type.
    fn introduces_nulls(&self) -> bool {
        Self::Output::nullable()
    }

    /// Whether this function might error on non-error input.
    ///
    /// By default this follows the fallibility reported by the output type.
    fn could_error(&self) -> bool {
        Self::Output::fallible()
    }

    /// Returns true if the function is monotone.
    fn is_monotone(&self) -> bool {
        false
    }

    /// Returns true if the function is associative.
    fn is_associative(&self) -> bool {
        false
    }

    /// Returns true if the function is an infix operator.
    fn is_infix_op(&self) -> bool {
        false
    }
}
1658
1659/// Blanket `LazyVariadicFunc` impl for each eager type, bridging
1660/// expression evaluation and null propagation via `InputDatumType::try_from_iter`.
1661impl<T: EagerVariadicFunc> LazyVariadicFunc for T {
1662    fn eval<'a>(
1663        &'a self,
1664        datums: &[Datum<'a>],
1665        temp_storage: &'a RowArena,
1666        exprs: &'a [impl Eval],
1667    ) -> Result<Datum<'a>, EvalError> {
1668        let mut datums = exprs.iter().map(|e| e.eval(datums, temp_storage));
1669        match T::Input::try_from_iter(&mut datums) {
1670            Ok(input) => self.call(input, temp_storage).into_result(temp_storage),
1671            Err(Ok(None)) => Err(EvalError::Internal("missing parameter".into())),
1672            Err(Ok(Some(datum))) if datum.is_null() => Ok(datum),
1673            Err(Ok(Some(_datum))) => {
1674                // datum is _not_ NULL
1675                Err(EvalError::Internal("invalid input type".into()))
1676            }
1677            Err(Err(res)) => Err(res),
1678        }
1679    }
1680
1681    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
1682        self.output_type(input_types)
1683    }
1684
1685    fn propagates_nulls(&self) -> bool {
1686        self.propagates_nulls()
1687    }
1688
1689    fn introduces_nulls(&self) -> bool {
1690        self.introduces_nulls()
1691    }
1692
1693    fn could_error(&self) -> bool {
1694        self.could_error()
1695    }
1696
1697    fn is_monotone(&self) -> bool {
1698        self.is_monotone()
1699    }
1700
1701    fn is_associative(&self) -> bool {
1702        self.is_associative()
1703    }
1704
1705    fn is_infix_op(&self) -> bool {
1706        self.is_infix_op()
1707    }
1708}
1709
// Registry of every variadic function, one `Variant(Type)` entry per function.
//
// NOTE(review): presumably this expands to the `VariadicFunc` enum used below,
// plus the conversions from each function type into it; confirm against the
// `derive_variadic!` definition.
derive_variadic! {
    Coalesce(Coalesce),
    Greatest(Greatest),
    Least(Least),
    Concat(Concat),
    ConcatWs(ConcatWs),
    MakeTimestamp(MakeTimestamp),
    PadLeading(PadLeading),
    Substr(Substr),
    Replace(Replace),
    JsonbBuildArray(JsonbBuildArray),
    JsonbBuildObject(JsonbBuildObject),
    MapBuild(MapBuild),
    ArrayCreate(ArrayCreate),
    ArrayToString(ArrayToString),
    ArrayIndex(ArrayIndex),
    ListCreate(ListCreate),
    RecordCreate(RecordCreate),
    ListIndex(ListIndex),
    ListSliceLinear(ListSliceLinear),
    SplitPart(SplitPart),
    RegexpMatch(RegexpMatch),
    HmacString(HmacString),
    HmacBytes(HmacBytes),
    ErrorIfNull(ErrorIfNull),
    DateBinTimestamp(DateBinTimestamp),
    DateBinTimestampTz(DateBinTimestampTz),
    DateDiffTimestamp(DateDiffTimestamp),
    DateDiffTimestampTz(DateDiffTimestampTz),
    DateDiffDate(DateDiffDate),
    DateDiffTime(DateDiffTime),
    And(And),
    Or(Or),
    RangeCreate(RangeCreate),
    MakeAclItem(MakeAclItem),
    MakeMzAclItem(MakeMzAclItem),
    Translate(Translate),
    ArrayPosition(ArrayPosition),
    ArrayFill(ArrayFill),
    StringToArray(StringToArray),
    TimezoneTimeVariadic(TimezoneTimeVariadic),
    RegexpSplitToArray(RegexpSplitToArray),
    RegexpReplace(RegexpReplace),
    CaseLiteral(CaseLiteral),
}
1755
1756impl VariadicFunc {
1757    pub fn switch_and_or(&self) -> Self {
1758        match self {
1759            VariadicFunc::And(_) => Or.into(),
1760            VariadicFunc::Or(_) => And.into(),
1761            _ => unreachable!(),
1762        }
1763    }
1764
1765    /// Gives the unit (u) of OR or AND, such that `u AND/OR x == x`.
1766    /// Note that a 0-arg AND/OR evaluates to unit_of_and_or.
1767    pub fn unit_of_and_or(&self) -> MirScalarExpr {
1768        match self {
1769            VariadicFunc::And(_) => MirScalarExpr::literal_true(),
1770            VariadicFunc::Or(_) => MirScalarExpr::literal_false(),
1771            _ => unreachable!(),
1772        }
1773    }
1774
1775    /// Gives the zero (z) of OR or AND, such that `z AND/OR x == z`.
1776    pub fn zero_of_and_or(&self) -> MirScalarExpr {
1777        match self {
1778            VariadicFunc::And(_) => MirScalarExpr::literal_false(),
1779            VariadicFunc::Or(_) => MirScalarExpr::literal_true(),
1780            _ => unreachable!(),
1781        }
1782    }
1783}