Skip to main content

mz_expr/scalar/
func.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14use std::borrow::Cow;
15use std::cmp::Ordering;
16use std::convert::{TryFrom, TryInto};
17use std::str::FromStr;
18use std::{iter, str};
19
20use ::encoding::DecoderTrap;
21use ::encoding::label::encoding_from_whatwg_label;
22use aws_lc_rs::constant_time::verify_slices_are_equal;
23use aws_lc_rs::digest;
24use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, TimeZone, Timelike, Utc};
25use chrono_tz::{OffsetComponents, OffsetName, Tz};
26use dec::OrderedDecimal;
27use itertools::Itertools;
28use md5::{Digest, Md5};
29use mz_expr_derive::sqlfunc;
30use mz_ore::cast::{self, CastFrom};
31use mz_ore::fmt::FormatBuffer;
32use mz_ore::lex::LexBuf;
33use mz_ore::option::OptionExt;
34use mz_pgrepr::Type;
35use mz_pgtz::timezone::{Timezone, TimezoneSpec};
36use mz_repr::adt::array::{Array, ArrayDimension};
37use mz_repr::adt::date::Date;
38use mz_repr::adt::interval::{Interval, RoundBehavior};
39use mz_repr::adt::jsonb::JsonbRef;
40use mz_repr::adt::mz_acl_item::{AclMode, MzAclItem};
41use mz_repr::adt::numeric::{self, Numeric};
42use mz_repr::adt::range::Range;
43use mz_repr::adt::regex::Regex;
44use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
45use mz_repr::{
46    ArrayRustType, Datum, DatumList, DatumMap, ExcludeNull, FromDatum, InputDatumType, Row,
47    RowArena, SqlScalarType, strconv,
48};
49use mz_sql_parser::ast::display::{AstDisplay, FormatMode};
50use mz_sql_pretty::{PrettyConfig, pretty_str};
51use num::traits::CheckedNeg;
52
53use crate::scalar::func::format::DateTimeFormat;
54use crate::{EvalError, like_pattern};
55
56#[macro_use]
57mod macros;
58mod binary;
59mod encoding;
60pub(crate) mod format;
61pub(crate) mod impls;
62mod unary;
63mod unmaterializable;
64pub mod variadic;
65
66pub use binary::BinaryFunc;
67pub use impls::*;
68pub use unary::{EagerUnaryFunc, LazyUnaryFunc, UnaryFunc};
69pub use unmaterializable::UnmaterializableFunc;
70pub use variadic::VariadicFunc;
71
/// Upper bound, in bytes, on the result of certain string functions such as
/// `repeat` and `lpad`.
///
/// The value (100 MiB) is the smallest limit that keeps our existing tests
/// passing unchanged. It is probably higher than we ultimately want, but some
/// limit is better than none.
///
/// Note: This number appears in our user-facing documentation in the function
/// reference for every function where it applies.
pub const MAX_STRING_FUNC_RESULT_BYTES: usize = 100 * 1024 * 1024;
79
80pub fn jsonb_stringify<'a>(a: Datum<'a>, temp_storage: &'a RowArena) -> Option<&'a str> {
81    match a {
82        Datum::JsonNull => None,
83        Datum::String(s) => Some(s),
84        _ => {
85            let s = cast_jsonb_to_string(JsonbRef::from_datum(a));
86            Some(temp_storage.push_string(s))
87        }
88    }
89}
90
91#[sqlfunc(
92    is_monotone = "(true, true)",
93    is_infix_op = true,
94    sqlname = "+",
95    propagates_nulls = true
96)]
97fn add_int16(a: i16, b: i16) -> Result<i16, EvalError> {
98    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
99}
100
101#[sqlfunc(
102    is_monotone = "(true, true)",
103    is_infix_op = true,
104    sqlname = "+",
105    propagates_nulls = true
106)]
107fn add_int32(a: i32, b: i32) -> Result<i32, EvalError> {
108    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
109}
110
111#[sqlfunc(
112    is_monotone = "(true, true)",
113    is_infix_op = true,
114    sqlname = "+",
115    propagates_nulls = true
116)]
117fn add_int64(a: i64, b: i64) -> Result<i64, EvalError> {
118    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
119}
120
121#[sqlfunc(
122    is_monotone = "(true, true)",
123    is_infix_op = true,
124    sqlname = "+",
125    propagates_nulls = true
126)]
127fn add_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
128    a.checked_add(b)
129        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} + {b}").into()))
130}
131
132#[sqlfunc(
133    is_monotone = "(true, true)",
134    is_infix_op = true,
135    sqlname = "+",
136    propagates_nulls = true
137)]
138fn add_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
139    a.checked_add(b)
140        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} + {b}").into()))
141}
142
143#[sqlfunc(
144    is_monotone = "(true, true)",
145    is_infix_op = true,
146    sqlname = "+",
147    propagates_nulls = true
148)]
149fn add_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
150    a.checked_add(b)
151        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} + {b}").into()))
152}
153
154#[sqlfunc(
155    is_monotone = "(true, true)",
156    is_infix_op = true,
157    sqlname = "+",
158    propagates_nulls = true
159)]
160fn add_float32(a: f32, b: f32) -> Result<f32, EvalError> {
161    let sum = a + b;
162    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
163        Err(EvalError::FloatOverflow)
164    } else {
165        Ok(sum)
166    }
167}
168
169#[sqlfunc(
170    is_monotone = "(true, true)",
171    is_infix_op = true,
172    sqlname = "+",
173    propagates_nulls = true
174)]
175fn add_float64(a: f64, b: f64) -> Result<f64, EvalError> {
176    let sum = a + b;
177    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
178        Err(EvalError::FloatOverflow)
179    } else {
180        Ok(sum)
181    }
182}
183
184#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
185fn add_timestamp_interval(
186    a: CheckedTimestamp<NaiveDateTime>,
187    b: Interval,
188) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
189    add_timestamplike_interval(a, b)
190}
191
192#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
193fn add_timestamp_tz_interval(
194    a: CheckedTimestamp<DateTime<Utc>>,
195    b: Interval,
196) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
197    add_timestamplike_interval(a, b)
198}
199
200fn add_timestamplike_interval<T>(
201    a: CheckedTimestamp<T>,
202    b: Interval,
203) -> Result<CheckedTimestamp<T>, EvalError>
204where
205    T: TimestampLike,
206{
207    let dt = a.date_time();
208    let dt = add_timestamp_months(&dt, b.months)?;
209    let dt = dt
210        .checked_add_signed(b.duration_as_chrono())
211        .ok_or(EvalError::TimestampOutOfRange)?;
212    Ok(CheckedTimestamp::from_timestamplike(T::from_date_time(dt))?)
213}
214
215#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
216fn sub_timestamp_interval(
217    a: CheckedTimestamp<NaiveDateTime>,
218    b: Interval,
219) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
220    sub_timestamplike_interval(a, b)
221}
222
223#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
224fn sub_timestamp_tz_interval(
225    a: CheckedTimestamp<DateTime<Utc>>,
226    b: Interval,
227) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
228    sub_timestamplike_interval(a, b)
229}
230
231fn sub_timestamplike_interval<T>(
232    a: CheckedTimestamp<T>,
233    b: Interval,
234) -> Result<CheckedTimestamp<T>, EvalError>
235where
236    T: TimestampLike,
237{
238    neg_interval_inner(b).and_then(|i| add_timestamplike_interval(a, i))
239}
240
241#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
242fn add_date_time(
243    date: Date,
244    time: chrono::NaiveTime,
245) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
246    let dt = NaiveDate::from(date)
247        .and_hms_nano_opt(time.hour(), time.minute(), time.second(), time.nanosecond())
248        .unwrap();
249    Ok(CheckedTimestamp::from_timestamplike(dt)?)
250}
251
252#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
253fn add_date_interval(
254    date: Date,
255    interval: Interval,
256) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
257    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
258    let dt = add_timestamp_months(&dt, interval.months)?;
259    let dt = dt
260        .checked_add_signed(interval.duration_as_chrono())
261        .ok_or(EvalError::TimestampOutOfRange)?;
262    Ok(CheckedTimestamp::from_timestamplike(dt)?)
263}
264
265#[sqlfunc(
266    // <time> + <interval> wraps!
267    is_monotone = "(false, false)",
268    is_infix_op = true,
269    sqlname = "+",
270    propagates_nulls = true
271)]
272fn add_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
273    let (t, _) = time.overflowing_add_signed(interval.duration_as_chrono());
274    t
275}
276
277#[sqlfunc(
278    is_monotone = "(true, false)",
279    output_type = "Numeric",
280    sqlname = "round",
281    propagates_nulls = true
282)]
283fn round_numeric_binary(a: OrderedDecimal<Numeric>, mut b: i32) -> Result<Numeric, EvalError> {
284    let mut a = a.0;
285    let mut cx = numeric::cx_datum();
286    let a_exp = a.exponent();
287    if a_exp > 0 && b > 0 || a_exp < 0 && -a_exp < b {
288        // This condition indicates:
289        // - a is a value without a decimal point, b is a positive number
290        // - a has a decimal point, but b is larger than its scale
291        // In both of these situations, right-pad the number with zeroes, which // is most easily done with rescale.
292
293        // Ensure rescale doesn't exceed max precision by putting a ceiling on
294        // b equal to the maximum remaining scale the value can support.
295        let max_remaining_scale = u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
296            - (numeric::get_precision(&a) - numeric::get_scale(&a));
297        b = match i32::try_from(max_remaining_scale) {
298            Ok(max_remaining_scale) => std::cmp::min(b, max_remaining_scale),
299            Err(_) => b,
300        };
301        cx.rescale(&mut a, &numeric::Numeric::from(-b));
302    } else {
303        // To avoid invalid operations, clamp b to be within 1 more than the
304        // precision limit.
305        const MAX_P_LIMIT: i32 = 1 + cast::u8_to_i32(numeric::NUMERIC_DATUM_MAX_PRECISION);
306        b = std::cmp::min(MAX_P_LIMIT, b);
307        b = std::cmp::max(-MAX_P_LIMIT, b);
308        let mut b = numeric::Numeric::from(b);
309        // Shift by 10^b; this put digit to round to in the one's place.
310        cx.scaleb(&mut a, &b);
311        cx.round(&mut a);
312        // Negate exponent for shift back
313        cx.neg(&mut b);
314        cx.scaleb(&mut a, &b);
315    }
316
317    if cx.status().overflow() {
318        Err(EvalError::FloatOverflow)
319    } else if a.is_zero() {
320        // simpler than handling cases where exponent has gotten set to some
321        // value greater than the max precision, but all significant digits
322        // were rounded away.
323        Ok(numeric::Numeric::zero())
324    } else {
325        numeric::munge_numeric(&mut a).unwrap();
326        Ok(a)
327    }
328}
329
330#[sqlfunc(sqlname = "convert_from", propagates_nulls = true)]
331fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
332    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
333    // which the encoding library uses[3].
334    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
335    // [2]: https://encoding.spec.whatwg.org/
336    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
337    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
338
339    // Supporting other encodings is tracked by database-issues#797.
340    if encoding_from_whatwg_label(&encoding_name).map(|e| e.name()) != Some("utf-8") {
341        return Err(EvalError::InvalidEncodingName(encoding_name));
342    }
343
344    match str::from_utf8(a) {
345        Ok(from) => Ok(from),
346        Err(e) => Err(EvalError::InvalidByteSequence {
347            byte_sequence: e.to_string().into(),
348            encoding_name,
349        }),
350    }
351}
352
353#[sqlfunc]
354fn encode(bytes: &[u8], format: &str) -> Result<String, EvalError> {
355    let format = encoding::lookup_format(format)?;
356    Ok(format.encode(bytes))
357}
358
359#[sqlfunc]
360fn decode(string: &str, format: &str) -> Result<Vec<u8>, EvalError> {
361    let format = encoding::lookup_format(format)?;
362    let out = format.decode(string)?;
363    if out.len() > MAX_STRING_FUNC_RESULT_BYTES {
364        Err(EvalError::LengthTooLarge)
365    } else {
366        Ok(out)
367    }
368}
369
370#[sqlfunc(sqlname = "length", propagates_nulls = true)]
371fn encoded_bytes_char_length(a: &[u8], b: &str) -> Result<i32, EvalError> {
372    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
373    // which the encoding library uses[3].
374    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
375    // [2]: https://encoding.spec.whatwg.org/
376    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
377    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
378
379    let enc = match encoding_from_whatwg_label(&encoding_name) {
380        Some(enc) => enc,
381        None => return Err(EvalError::InvalidEncodingName(encoding_name)),
382    };
383
384    let decoded_string = match enc.decode(a, DecoderTrap::Strict) {
385        Ok(s) => s,
386        Err(e) => {
387            return Err(EvalError::InvalidByteSequence {
388                byte_sequence: e.into(),
389                encoding_name,
390            });
391        }
392    };
393
394    let count = decoded_string.chars().count();
395    i32::try_from(count).map_err(|_| EvalError::Int32OutOfRange(count.to_string().into()))
396}
397
398// TODO(benesch): remove potentially dangerous usage of `as`.
399#[allow(clippy::as_conversions)]
400pub fn add_timestamp_months<T: TimestampLike>(
401    dt: &T,
402    mut months: i32,
403) -> Result<CheckedTimestamp<T>, EvalError> {
404    if months == 0 {
405        return Ok(CheckedTimestamp::from_timestamplike(dt.clone())?);
406    }
407
408    let (mut year, mut month, mut day) = (dt.year(), dt.month0() as i32, dt.day());
409    let years = months / 12;
410    year = year
411        .checked_add(years)
412        .ok_or(EvalError::TimestampOutOfRange)?;
413
414    months %= 12;
415    // positive modulus is easier to reason about
416    if months < 0 {
417        year -= 1;
418        months += 12;
419    }
420    year += (month + months) / 12;
421    month = (month + months) % 12;
422    // account for dt.month0
423    month += 1;
424
425    // handle going from January 31st to February by saturation
426    let mut new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
427    while new_d.is_none() {
428        // If we have decremented day past 28 and are still receiving `None`,
429        // then we have generally overflowed `NaiveDate`.
430        if day < 28 {
431            return Err(EvalError::TimestampOutOfRange);
432        }
433        day -= 1;
434        new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
435    }
436    let new_d = new_d.unwrap();
437
438    // Neither postgres nor mysql support leap seconds, so this should be safe.
439    //
440    // Both my testing and https://dba.stackexchange.com/a/105829 support the
441    // idea that we should ignore leap seconds
442    let new_dt = new_d
443        .and_hms_nano_opt(dt.hour(), dt.minute(), dt.second(), dt.nanosecond())
444        .unwrap();
445    let new_dt = T::from_date_time(new_dt);
446    Ok(CheckedTimestamp::from_timestamplike(new_dt)?)
447}
448
449#[sqlfunc(
450    is_monotone = "(true, true)",
451    is_infix_op = true,
452    sqlname = "+",
453    propagates_nulls = true
454)]
455fn add_numeric(
456    a: OrderedDecimal<Numeric>,
457    b: OrderedDecimal<Numeric>,
458) -> Result<Numeric, EvalError> {
459    let mut cx = numeric::cx_datum();
460    let mut a = a.0;
461    cx.add(&mut a, &b.0);
462    if cx.status().overflow() {
463        Err(EvalError::FloatOverflow)
464    } else {
465        Ok(a)
466    }
467}
468
469#[sqlfunc(
470    is_monotone = "(true, true)",
471    is_infix_op = true,
472    sqlname = "+",
473    propagates_nulls = true
474)]
475fn add_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
476    a.checked_add(&b)
477        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} + {b}").into()))
478}
479
480#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
481fn bit_and_int16(a: i16, b: i16) -> i16 {
482    a & b
483}
484
485#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
486fn bit_and_int32(a: i32, b: i32) -> i32 {
487    a & b
488}
489
490#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
491fn bit_and_int64(a: i64, b: i64) -> i64 {
492    a & b
493}
494
495#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
496fn bit_and_uint16(a: u16, b: u16) -> u16 {
497    a & b
498}
499
500#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
501fn bit_and_uint32(a: u32, b: u32) -> u32 {
502    a & b
503}
504
505#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
506fn bit_and_uint64(a: u64, b: u64) -> u64 {
507    a & b
508}
509
510#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
511fn bit_or_int16(a: i16, b: i16) -> i16 {
512    a | b
513}
514
515#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
516fn bit_or_int32(a: i32, b: i32) -> i32 {
517    a | b
518}
519
520#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
521fn bit_or_int64(a: i64, b: i64) -> i64 {
522    a | b
523}
524
525#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
526fn bit_or_uint16(a: u16, b: u16) -> u16 {
527    a | b
528}
529
530#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
531fn bit_or_uint32(a: u32, b: u32) -> u32 {
532    a | b
533}
534
535#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
536fn bit_or_uint64(a: u64, b: u64) -> u64 {
537    a | b
538}
539
540#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
541fn bit_xor_int16(a: i16, b: i16) -> i16 {
542    a ^ b
543}
544
545#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
546fn bit_xor_int32(a: i32, b: i32) -> i32 {
547    a ^ b
548}
549
550#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
551fn bit_xor_int64(a: i64, b: i64) -> i64 {
552    a ^ b
553}
554
555#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
556fn bit_xor_uint16(a: u16, b: u16) -> u16 {
557    a ^ b
558}
559
560#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
561fn bit_xor_uint32(a: u32, b: u32) -> u32 {
562    a ^ b
563}
564
565#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
566fn bit_xor_uint64(a: u64, b: u64) -> u64 {
567    a ^ b
568}
569
570#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
571// TODO(benesch): remove potentially dangerous usage of `as`.
572#[allow(clippy::as_conversions)]
573fn bit_shift_left_int16(a: i16, b: i32) -> i16 {
574    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
575    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
576    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
577    let lhs: i32 = a as i32;
578    let rhs: u32 = b as u32;
579    lhs.wrapping_shl(rhs) as i16
580}
581
582#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
583// TODO(benesch): remove potentially dangerous usage of `as`.
584#[allow(clippy::as_conversions)]
585fn bit_shift_left_int32(lhs: i32, rhs: i32) -> i32 {
586    let rhs = rhs as u32;
587    lhs.wrapping_shl(rhs)
588}
589
590#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
591// TODO(benesch): remove potentially dangerous usage of `as`.
592#[allow(clippy::as_conversions)]
593fn bit_shift_left_int64(lhs: i64, rhs: i32) -> i64 {
594    let rhs = rhs as u32;
595    lhs.wrapping_shl(rhs)
596}
597
598#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
599// TODO(benesch): remove potentially dangerous usage of `as`.
600#[allow(clippy::as_conversions)]
601fn bit_shift_left_uint16(a: u16, b: u32) -> u16 {
602    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
603    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
604    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
605    let lhs: u32 = a as u32;
606    let rhs: u32 = b;
607    lhs.wrapping_shl(rhs) as u16
608}
609
610#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
611fn bit_shift_left_uint32(a: u32, b: u32) -> u32 {
612    let lhs = a;
613    let rhs = b;
614    lhs.wrapping_shl(rhs)
615}
616
617#[sqlfunc(
618    output_type = "u64",
619    is_infix_op = true,
620    sqlname = "<<",
621    propagates_nulls = true
622)]
623fn bit_shift_left_uint64(lhs: u64, rhs: u32) -> u64 {
624    lhs.wrapping_shl(rhs)
625}
626
627#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
628// TODO(benesch): remove potentially dangerous usage of `as`.
629#[allow(clippy::as_conversions)]
630fn bit_shift_right_int16(lhs: i16, rhs: i32) -> i16 {
631    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
632    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
633    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
634    let lhs = lhs as i32;
635    let rhs = rhs as u32;
636    lhs.wrapping_shr(rhs) as i16
637}
638
639#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
640// TODO(benesch): remove potentially dangerous usage of `as`.
641#[allow(clippy::as_conversions)]
642fn bit_shift_right_int32(lhs: i32, rhs: i32) -> i32 {
643    lhs.wrapping_shr(rhs as u32)
644}
645
646#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
647// TODO(benesch): remove potentially dangerous usage of `as`.
648#[allow(clippy::as_conversions)]
649fn bit_shift_right_int64(lhs: i64, rhs: i32) -> i64 {
650    lhs.wrapping_shr(rhs as u32)
651}
652
653#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
654// TODO(benesch): remove potentially dangerous usage of `as`.
655#[allow(clippy::as_conversions)]
656fn bit_shift_right_uint16(lhs: u16, rhs: u32) -> u16 {
657    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
658    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
659    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
660    let lhs = lhs as u32;
661    lhs.wrapping_shr(rhs) as u16
662}
663
664#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
665fn bit_shift_right_uint32(lhs: u32, rhs: u32) -> u32 {
666    lhs.wrapping_shr(rhs)
667}
668
669#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
670fn bit_shift_right_uint64(lhs: u64, rhs: u32) -> u64 {
671    lhs.wrapping_shr(rhs)
672}
673
674#[sqlfunc(
675    is_monotone = "(true, true)",
676    is_infix_op = true,
677    sqlname = "-",
678    propagates_nulls = true
679)]
680fn sub_int16(a: i16, b: i16) -> Result<i16, EvalError> {
681    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
682}
683
684#[sqlfunc(
685    is_monotone = "(true, true)",
686    is_infix_op = true,
687    sqlname = "-",
688    propagates_nulls = true
689)]
690fn sub_int32(a: i32, b: i32) -> Result<i32, EvalError> {
691    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
692}
693
694#[sqlfunc(
695    is_monotone = "(true, true)",
696    is_infix_op = true,
697    sqlname = "-",
698    propagates_nulls = true
699)]
700fn sub_int64(a: i64, b: i64) -> Result<i64, EvalError> {
701    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
702}
703
704#[sqlfunc(
705    is_monotone = "(true, true)",
706    is_infix_op = true,
707    sqlname = "-",
708    propagates_nulls = true
709)]
710fn sub_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
711    a.checked_sub(b)
712        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} - {b}").into()))
713}
714
715#[sqlfunc(
716    is_monotone = "(true, true)",
717    is_infix_op = true,
718    sqlname = "-",
719    propagates_nulls = true
720)]
721fn sub_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
722    a.checked_sub(b)
723        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} - {b}").into()))
724}
725
726#[sqlfunc(
727    is_monotone = "(true, true)",
728    is_infix_op = true,
729    sqlname = "-",
730    propagates_nulls = true
731)]
732fn sub_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
733    a.checked_sub(b)
734        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} - {b}").into()))
735}
736
737#[sqlfunc(
738    is_monotone = "(true, true)",
739    is_infix_op = true,
740    sqlname = "-",
741    propagates_nulls = true
742)]
743fn sub_float32(a: f32, b: f32) -> Result<f32, EvalError> {
744    let difference = a - b;
745    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
746        Err(EvalError::FloatOverflow)
747    } else {
748        Ok(difference)
749    }
750}
751
752#[sqlfunc(
753    is_monotone = "(true, true)",
754    is_infix_op = true,
755    sqlname = "-",
756    propagates_nulls = true
757)]
758fn sub_float64(a: f64, b: f64) -> Result<f64, EvalError> {
759    let difference = a - b;
760    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
761        Err(EvalError::FloatOverflow)
762    } else {
763        Ok(difference)
764    }
765}
766
767#[sqlfunc(
768    is_monotone = "(true, true)",
769    is_infix_op = true,
770    sqlname = "-",
771    propagates_nulls = true
772)]
773fn sub_numeric(
774    a: OrderedDecimal<Numeric>,
775    b: OrderedDecimal<Numeric>,
776) -> Result<Numeric, EvalError> {
777    let mut cx = numeric::cx_datum();
778    let mut a = a.0;
779    cx.sub(&mut a, &b.0);
780    if cx.status().overflow() {
781        Err(EvalError::FloatOverflow)
782    } else {
783        Ok(a)
784    }
785}
786
787#[sqlfunc(
788    is_monotone = "(true, true)",
789    output_type = "Interval",
790    sqlname = "age",
791    propagates_nulls = true
792)]
793fn age_timestamp(
794    a: CheckedTimestamp<chrono::NaiveDateTime>,
795    b: CheckedTimestamp<chrono::NaiveDateTime>,
796) -> Result<Interval, EvalError> {
797    Ok(a.age(&b)?)
798}
799
800#[sqlfunc(is_monotone = "(true, true)", sqlname = "age", propagates_nulls = true)]
801fn age_timestamp_tz(
802    a: CheckedTimestamp<chrono::DateTime<Utc>>,
803    b: CheckedTimestamp<chrono::DateTime<Utc>>,
804) -> Result<Interval, EvalError> {
805    Ok(a.age(&b)?)
806}
807
808#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
809fn sub_timestamp(
810    a: CheckedTimestamp<NaiveDateTime>,
811    b: CheckedTimestamp<NaiveDateTime>,
812) -> Result<Interval, EvalError> {
813    Interval::from_chrono_duration(a - b)
814        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
815}
816
817#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
818fn sub_timestamp_tz(
819    a: CheckedTimestamp<chrono::DateTime<Utc>>,
820    b: CheckedTimestamp<chrono::DateTime<Utc>>,
821) -> Result<Interval, EvalError> {
822    Interval::from_chrono_duration(a - b)
823        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
824}
825
826#[sqlfunc(
827    is_monotone = "(true, true)",
828    is_infix_op = true,
829    sqlname = "-",
830    propagates_nulls = true
831)]
832fn sub_date(a: Date, b: Date) -> i32 {
833    a - b
834}
835
836#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
837fn sub_time(a: chrono::NaiveTime, b: chrono::NaiveTime) -> Result<Interval, EvalError> {
838    Interval::from_chrono_duration(a - b)
839        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
840}
841
842#[sqlfunc(
843    is_monotone = "(true, true)",
844    output_type = "Interval",
845    is_infix_op = true,
846    sqlname = "-",
847    propagates_nulls = true
848)]
849fn sub_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
850    b.checked_neg()
851        .and_then(|b| b.checked_add(&a))
852        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} - {b}").into()))
853}
854
855#[sqlfunc(
856    is_monotone = "(true, true)",
857    is_infix_op = true,
858    sqlname = "-",
859    propagates_nulls = true
860)]
861fn sub_date_interval(
862    date: Date,
863    interval: Interval,
864) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
865    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
866    let dt = interval
867        .months
868        .checked_neg()
869        .ok_or_else(|| EvalError::IntervalOutOfRange(interval.months.to_string().into()))
870        .and_then(|months| add_timestamp_months(&dt, months))?;
871    let dt = dt
872        .checked_sub_signed(interval.duration_as_chrono())
873        .ok_or(EvalError::TimestampOutOfRange)?;
874    Ok(dt.try_into()?)
875}
876
877#[sqlfunc(
878    is_monotone = "(false, false)",
879    is_infix_op = true,
880    sqlname = "-",
881    propagates_nulls = true
882)]
883fn sub_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
884    let (t, _) = time.overflowing_sub_signed(interval.duration_as_chrono());
885    t
886}
887
888#[sqlfunc(
889    is_monotone = "(true, true)",
890    is_infix_op = true,
891    sqlname = "*",
892    propagates_nulls = true
893)]
894fn mul_int16(a: i16, b: i16) -> Result<i16, EvalError> {
895    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
896}
897
898#[sqlfunc(
899    is_monotone = "(true, true)",
900    is_infix_op = true,
901    sqlname = "*",
902    propagates_nulls = true
903)]
904fn mul_int32(a: i32, b: i32) -> Result<i32, EvalError> {
905    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
906}
907
908#[sqlfunc(
909    is_monotone = "(true, true)",
910    is_infix_op = true,
911    sqlname = "*",
912    propagates_nulls = true
913)]
914fn mul_int64(a: i64, b: i64) -> Result<i64, EvalError> {
915    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
916}
917
918#[sqlfunc(
919    is_monotone = "(true, true)",
920    is_infix_op = true,
921    sqlname = "*",
922    propagates_nulls = true
923)]
924fn mul_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
925    a.checked_mul(b)
926        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} * {b}").into()))
927}
928
929#[sqlfunc(
930    is_monotone = "(true, true)",
931    is_infix_op = true,
932    sqlname = "*",
933    propagates_nulls = true
934)]
935fn mul_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
936    a.checked_mul(b)
937        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} * {b}").into()))
938}
939
940#[sqlfunc(
941    is_monotone = "(true, true)",
942    is_infix_op = true,
943    sqlname = "*",
944    propagates_nulls = true
945)]
946fn mul_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
947    a.checked_mul(b)
948        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} * {b}").into()))
949}
950
951#[sqlfunc(
952    is_monotone = (true, true),
953    is_infix_op = true,
954    sqlname = "*",
955    propagates_nulls = true
956)]
957fn mul_float32(a: f32, b: f32) -> Result<f32, EvalError> {
958    let product = a * b;
959    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
960        Err(EvalError::FloatOverflow)
961    } else if product == 0.0f32 && a != 0.0f32 && b != 0.0f32 {
962        Err(EvalError::FloatUnderflow)
963    } else {
964        Ok(product)
965    }
966}
967
968#[sqlfunc(
969    is_monotone = "(true, true)",
970    is_infix_op = true,
971    sqlname = "*",
972    propagates_nulls = true
973)]
974fn mul_float64(a: f64, b: f64) -> Result<f64, EvalError> {
975    let product = a * b;
976    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
977        Err(EvalError::FloatOverflow)
978    } else if product == 0.0f64 && a != 0.0f64 && b != 0.0f64 {
979        Err(EvalError::FloatUnderflow)
980    } else {
981        Ok(product)
982    }
983}
984
985#[sqlfunc(
986    is_monotone = "(true, true)",
987    is_infix_op = true,
988    sqlname = "*",
989    propagates_nulls = true
990)]
991fn mul_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
992    let mut cx = numeric::cx_datum();
993    cx.mul(&mut a, &b);
994    let cx_status = cx.status();
995    if cx_status.overflow() {
996        Err(EvalError::FloatOverflow)
997    } else if cx_status.subnormal() {
998        Err(EvalError::FloatUnderflow)
999    } else {
1000        numeric::munge_numeric(&mut a).unwrap();
1001        Ok(a)
1002    }
1003}
1004
1005#[sqlfunc(
1006    is_monotone = "(false, false)",
1007    is_infix_op = true,
1008    sqlname = "*",
1009    propagates_nulls = true
1010)]
1011fn mul_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1012    a.checked_mul(b)
1013        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} * {b}").into()))
1014}
1015
1016#[sqlfunc(
1017    is_monotone = "(true, false)",
1018    is_infix_op = true,
1019    sqlname = "/",
1020    propagates_nulls = true
1021)]
1022fn div_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1023    if b == 0 {
1024        Err(EvalError::DivisionByZero)
1025    } else {
1026        a.checked_div(b)
1027            .ok_or_else(|| EvalError::Int16OutOfRange(format!("{a} / {b}").into()))
1028    }
1029}
1030
1031#[sqlfunc(
1032    is_monotone = "(true, false)",
1033    is_infix_op = true,
1034    sqlname = "/",
1035    propagates_nulls = true
1036)]
1037fn div_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1038    if b == 0 {
1039        Err(EvalError::DivisionByZero)
1040    } else {
1041        a.checked_div(b)
1042            .ok_or_else(|| EvalError::Int32OutOfRange(format!("{a} / {b}").into()))
1043    }
1044}
1045
1046#[sqlfunc(
1047    is_monotone = "(true, false)",
1048    is_infix_op = true,
1049    sqlname = "/",
1050    propagates_nulls = true
1051)]
1052fn div_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1053    if b == 0 {
1054        Err(EvalError::DivisionByZero)
1055    } else {
1056        a.checked_div(b)
1057            .ok_or_else(|| EvalError::Int64OutOfRange(format!("{a} / {b}").into()))
1058    }
1059}
1060
1061#[sqlfunc(
1062    is_monotone = "(true, false)",
1063    is_infix_op = true,
1064    sqlname = "/",
1065    propagates_nulls = true
1066)]
1067fn div_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1068    if b == 0 {
1069        Err(EvalError::DivisionByZero)
1070    } else {
1071        Ok(a / b)
1072    }
1073}
1074
1075#[sqlfunc(
1076    is_monotone = "(true, false)",
1077    is_infix_op = true,
1078    sqlname = "/",
1079    propagates_nulls = true
1080)]
1081fn div_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1082    if b == 0 {
1083        Err(EvalError::DivisionByZero)
1084    } else {
1085        Ok(a / b)
1086    }
1087}
1088
1089#[sqlfunc(
1090    is_monotone = "(true, false)",
1091    is_infix_op = true,
1092    sqlname = "/",
1093    propagates_nulls = true
1094)]
1095fn div_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1096    if b == 0 {
1097        Err(EvalError::DivisionByZero)
1098    } else {
1099        Ok(a / b)
1100    }
1101}
1102
1103#[sqlfunc(
1104    is_monotone = "(true, false)",
1105    is_infix_op = true,
1106    sqlname = "/",
1107    propagates_nulls = true
1108)]
1109fn div_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1110    if b == 0.0f32 && !a.is_nan() {
1111        Err(EvalError::DivisionByZero)
1112    } else {
1113        let quotient = a / b;
1114        if quotient.is_infinite() && !a.is_infinite() {
1115            Err(EvalError::FloatOverflow)
1116        } else if quotient == 0.0f32 && a != 0.0f32 && !b.is_infinite() {
1117            Err(EvalError::FloatUnderflow)
1118        } else {
1119            Ok(quotient)
1120        }
1121    }
1122}
1123
1124#[sqlfunc(
1125    is_monotone = "(true, false)",
1126    is_infix_op = true,
1127    sqlname = "/",
1128    propagates_nulls = true
1129)]
1130fn div_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1131    if b == 0.0f64 && !a.is_nan() {
1132        Err(EvalError::DivisionByZero)
1133    } else {
1134        let quotient = a / b;
1135        if quotient.is_infinite() && !a.is_infinite() {
1136            Err(EvalError::FloatOverflow)
1137        } else if quotient == 0.0f64 && a != 0.0f64 && !b.is_infinite() {
1138            Err(EvalError::FloatUnderflow)
1139        } else {
1140            Ok(quotient)
1141        }
1142    }
1143}
1144
1145#[sqlfunc(
1146    is_monotone = "(true, false)",
1147    is_infix_op = true,
1148    sqlname = "/",
1149    propagates_nulls = true
1150)]
1151fn div_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1152    let mut cx = numeric::cx_datum();
1153
1154    cx.div(&mut a, &b);
1155    let cx_status = cx.status();
1156
1157    // checking the status for division by zero errors is insufficient because
1158    // the underlying library treats 0/0 as undefined and not division by zero.
1159    if b.is_zero() {
1160        Err(EvalError::DivisionByZero)
1161    } else if cx_status.overflow() {
1162        Err(EvalError::FloatOverflow)
1163    } else if cx_status.subnormal() {
1164        Err(EvalError::FloatUnderflow)
1165    } else {
1166        numeric::munge_numeric(&mut a).unwrap();
1167        Ok(a)
1168    }
1169}
1170
1171#[sqlfunc(
1172    is_monotone = "(false, false)",
1173    is_infix_op = true,
1174    sqlname = "/",
1175    propagates_nulls = true
1176)]
1177fn div_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1178    if b == 0.0 {
1179        Err(EvalError::DivisionByZero)
1180    } else {
1181        a.checked_div(b)
1182            .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} / {b}").into()))
1183    }
1184}
1185
1186#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1187fn mod_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1188    if b == 0 {
1189        Err(EvalError::DivisionByZero)
1190    } else {
1191        Ok(a.checked_rem(b).unwrap_or(0))
1192    }
1193}
1194
1195#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1196fn mod_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1197    if b == 0 {
1198        Err(EvalError::DivisionByZero)
1199    } else {
1200        Ok(a.checked_rem(b).unwrap_or(0))
1201    }
1202}
1203
1204#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1205fn mod_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1206    if b == 0 {
1207        Err(EvalError::DivisionByZero)
1208    } else {
1209        Ok(a.checked_rem(b).unwrap_or(0))
1210    }
1211}
1212
1213#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1214fn mod_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1215    if b == 0 {
1216        Err(EvalError::DivisionByZero)
1217    } else {
1218        Ok(a % b)
1219    }
1220}
1221
1222#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1223fn mod_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1224    if b == 0 {
1225        Err(EvalError::DivisionByZero)
1226    } else {
1227        Ok(a % b)
1228    }
1229}
1230
1231#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1232fn mod_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1233    if b == 0 {
1234        Err(EvalError::DivisionByZero)
1235    } else {
1236        Ok(a % b)
1237    }
1238}
1239
1240#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1241fn mod_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1242    if b == 0.0 {
1243        Err(EvalError::DivisionByZero)
1244    } else {
1245        Ok(a % b)
1246    }
1247}
1248
1249#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1250fn mod_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1251    if b == 0.0 {
1252        Err(EvalError::DivisionByZero)
1253    } else {
1254        Ok(a % b)
1255    }
1256}
1257
1258#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1259fn mod_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1260    if b.is_zero() {
1261        return Err(EvalError::DivisionByZero);
1262    }
1263    let mut cx = numeric::cx_datum();
1264    // Postgres does _not_ use IEEE 754-style remainder
1265    cx.rem(&mut a, &b);
1266    numeric::munge_numeric(&mut a).unwrap();
1267    Ok(a)
1268}
1269
1270fn neg_interval_inner(a: Interval) -> Result<Interval, EvalError> {
1271    a.checked_neg()
1272        .ok_or_else(|| EvalError::IntervalOutOfRange(a.to_string().into()))
1273}
1274
1275fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
1276    if val.is_negative() {
1277        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
1278    }
1279    if val.is_zero() {
1280        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
1281    }
1282    Ok(())
1283}
1284
1285#[sqlfunc(sqlname = "log", propagates_nulls = true)]
1286fn log_base_numeric(mut a: Numeric, mut b: Numeric) -> Result<Numeric, EvalError> {
1287    log_guard_numeric(&a, "log")?;
1288    log_guard_numeric(&b, "log")?;
1289    let mut cx = numeric::cx_datum();
1290    cx.ln(&mut a);
1291    cx.ln(&mut b);
1292    cx.div(&mut b, &a);
1293    if a.is_zero() {
1294        Err(EvalError::DivisionByZero)
1295    } else {
1296        // This division can result in slightly wrong answers due to the
1297        // limitation of dividing irrational numbers. To correct that, see if
1298        // rounding off the value from its `numeric::NUMERIC_DATUM_MAX_PRECISION
1299        // - 1`th position results in an integral value.
1300        cx.set_precision(usize::from(numeric::NUMERIC_DATUM_MAX_PRECISION - 1))
1301            .expect("reducing precision below max always succeeds");
1302        let mut integral_check = b.clone();
1303
1304        // `reduce` rounds to the context's final digit when the number of
1305        // digits in its argument exceeds its precision. We've contrived that to
1306        // happen by shrinking the context's precision by 1.
1307        cx.reduce(&mut integral_check);
1308
1309        // Reduced integral values always have a non-negative exponent.
1310        let mut b = if integral_check.exponent() >= 0 {
1311            // We believe our result should have been an integral
1312            integral_check
1313        } else {
1314            b
1315        };
1316
1317        numeric::munge_numeric(&mut b).unwrap();
1318        Ok(b)
1319    }
1320}
1321
1322#[sqlfunc(propagates_nulls = true)]
1323fn power(a: f64, b: f64) -> Result<f64, EvalError> {
1324    if a == 0.0 && b.is_sign_negative() {
1325        return Err(EvalError::Undefined(
1326            "zero raised to a negative power".into(),
1327        ));
1328    }
1329    if a.is_sign_negative() && b.fract() != 0.0 {
1330        // Equivalent to PG error:
1331        // > a negative number raised to a non-integer power yields a complex result
1332        return Err(EvalError::ComplexOutOfRange("pow".into()));
1333    }
1334    let res = a.powf(b);
1335    if res.is_infinite() {
1336        return Err(EvalError::FloatOverflow);
1337    }
1338    if res == 0.0 && a != 0.0 {
1339        return Err(EvalError::FloatUnderflow);
1340    }
1341    Ok(res)
1342}
1343
1344#[sqlfunc(propagates_nulls = true)]
1345fn uuid_generate_v5(a: uuid::Uuid, b: &str) -> uuid::Uuid {
1346    uuid::Uuid::new_v5(&a, b.as_bytes())
1347}
1348
// `power(a, b)` for `numeric`.
//
// Special cases are handled before the decimal library is consulted:
// `0 ^ 0` is 1, `0 ^ negative` is undefined, and `negative ^ non-integer`
// is a complex-result error. Afterwards, the decimal context's status flags
// are mapped onto `FloatOverflow` / `FloatUnderflow`.
#[sqlfunc(output_type = "Numeric", propagates_nulls = true)]
fn power_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
    if a.is_zero() {
        if b.is_zero() {
            return Ok(Numeric::from(1));
        }
        if b.is_negative() {
            return Err(EvalError::Undefined(
                "zero raised to a negative power".into(),
            ));
        }
    }
    // A negative exponent is used as the "non-integer" test.
    // NOTE(review): this presumably relies on `b` having trailing fractional
    // zeros trimmed (e.g. `2.0` stored with exponent 0); confirm against
    // `munge_numeric`.
    if a.is_negative() && b.exponent() < 0 {
        // Equivalent to PG error:
        // > a negative number raised to a non-integer power yields a complex result
        return Err(EvalError::ComplexOutOfRange("pow".into()));
    }
    let mut cx = numeric::cx_datum();
    // The result is computed in place into `a`.
    cx.pow(&mut a, &b);
    let cx_status = cx.status();
    // An invalid operation is classified by the exponent's sign: overflow for
    // a non-negative exponent, underflow (next branch) for a negative one.
    if cx_status.overflow() || (cx_status.invalid_operation() && !b.is_negative()) {
        Err(EvalError::FloatOverflow)
    } else if cx_status.subnormal() || cx_status.invalid_operation() {
        Err(EvalError::FloatUnderflow)
    } else {
        numeric::munge_numeric(&mut a).unwrap();
        Ok(a)
    }
}
1378
1379#[sqlfunc(propagates_nulls = true)]
1380fn get_bit(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1381    let err = EvalError::IndexOutOfRange {
1382        provided: index,
1383        valid_end: i32::try_from(bytes.len().saturating_mul(8)).unwrap() - 1,
1384    };
1385
1386    let index = usize::try_from(index).map_err(|_| err.clone())?;
1387
1388    let byte_index = index / 8;
1389    let bit_index = index % 8;
1390
1391    let i = bytes
1392        .get(byte_index)
1393        .map(|b| (*b >> bit_index) & 1)
1394        .ok_or(err)?;
1395    assert!(i == 0 || i == 1);
1396    Ok(i32::from(i))
1397}
1398
1399#[sqlfunc(propagates_nulls = true)]
1400fn get_byte(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1401    let err = EvalError::IndexOutOfRange {
1402        provided: index,
1403        valid_end: i32::try_from(bytes.len()).unwrap() - 1,
1404    };
1405    let i: &u8 = bytes
1406        .get(usize::try_from(index).map_err(|_| err.clone())?)
1407        .ok_or(err)?;
1408    Ok(i32::from(*i))
1409}
1410
1411#[sqlfunc(sqlname = "constant_time_compare_bytes", propagates_nulls = true)]
1412pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
1413    verify_slices_are_equal(a, b).is_ok()
1414}
1415
1416#[sqlfunc(sqlname = "constant_time_compare_strings", propagates_nulls = true)]
1417pub fn constant_time_eq_string(a: &str, b: &str) -> bool {
1418    verify_slices_are_equal(a.as_bytes(), b.as_bytes()).is_ok()
1419}
1420
1421#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1422fn range_contains_i32<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1423    a.contains_elem(&b)
1424}
1425
1426#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1427fn range_contains_i64<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1428    a.contains_elem(&elem)
1429}
1430
1431#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1432fn range_contains_date<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1433    a.contains_elem(&elem)
1434}
1435
1436#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1437fn range_contains_numeric<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1438    a.contains_elem(&elem)
1439}
1440
1441#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1442fn range_contains_timestamp<'a>(
1443    a: Range<Datum<'a>>,
1444    elem: CheckedTimestamp<NaiveDateTime>,
1445) -> bool {
1446    a.contains_elem(&elem)
1447}
1448
1449#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1450fn range_contains_timestamp_tz<'a>(
1451    a: Range<Datum<'a>>,
1452    elem: CheckedTimestamp<DateTime<Utc>>,
1453) -> bool {
1454    a.contains_elem(&elem)
1455}
1456
1457#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1458fn range_contains_i32_rev<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1459    a.contains_elem(&b)
1460}
1461
1462#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1463fn range_contains_i64_rev<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1464    a.contains_elem(&elem)
1465}
1466
1467#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1468fn range_contains_date_rev<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1469    a.contains_elem(&elem)
1470}
1471
1472#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1473fn range_contains_numeric_rev<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1474    a.contains_elem(&elem)
1475}
1476
1477#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1478fn range_contains_timestamp_rev<'a>(
1479    a: Range<Datum<'a>>,
1480    elem: CheckedTimestamp<NaiveDateTime>,
1481) -> bool {
1482    a.contains_elem(&elem)
1483}
1484
1485#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1486fn range_contains_timestamp_tz_rev<'a>(
1487    a: Range<Datum<'a>>,
1488    elem: CheckedTimestamp<DateTime<Utc>>,
1489) -> bool {
1490    a.contains_elem(&elem)
1491}
1492
/// Macro to define binary function for various range operations.
/// Parameters:
/// 1. Unique binary function symbol.
/// 2. Range function symbol.
/// 3. SQL name for the function.
///
/// Each invocation defines `fn range_<$fn>(a: Datum, b: Datum) -> Datum`.
/// The generated function unwraps both arguments as ranges, calls
/// `Range::<Datum>::$range_fn(&l, &r)`, and wraps the boolean result in a
/// `Datum`. A NULL in either argument yields `Datum::Null`.
macro_rules! range_fn {
    ($fn:expr, $range_fn:expr, $sqlname:expr) => {
        paste::paste! {

            #[sqlfunc(
                output_type = "bool",
                is_infix_op = true,
                sqlname = $sqlname,
                propagates_nulls = true
            )]
            fn [< range_ $fn >]<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a>
            {
                // Either argument being NULL makes the result NULL.
                if a.is_null() || b.is_null() { return Datum::Null }
                let l = a.unwrap_range();
                let r = b.unwrap_range();
                Datum::from(Range::<Datum<'a>>::$range_fn(&l, &r))
            }
        }
    };
}

// RangeContainsRange is either @> or <@ depending on the order of the arguments.
// It doesn't influence the result, but it does influence the display string:
// both variants below call `Range::contains_range(&l, &r)` with the same
// argument order, and only the SQL name differs.
range_fn!(contains_range, contains_range, "@>");
range_fn!(contains_range_rev, contains_range, "<@");
range_fn!(overlaps, overlaps, "&&");
range_fn!(after, after, ">>");
range_fn!(before, before, "<<");
range_fn!(overleft, overleft, "&<");
range_fn!(overright, overright, "&>");
range_fn!(adjacent, adjacent, "-|-");
1529
1530#[sqlfunc(is_infix_op = true, sqlname = "+")]
1531fn range_union<T: Copy + Ord>(l: Range<T>, r: Range<T>) -> Result<Range<T>, EvalError> {
1532    Ok(l.union(&r)?)
1533}
1534
// `range * range`: the intersection of `l` and `r`, as computed by
// `Range::intersection`.
#[sqlfunc(is_infix_op = true, sqlname = "*")]
fn range_intersection<T: Copy + Ord>(l: Range<T>, r: Range<T>) -> Range<T> {
    l.intersection(&r)
}
1539
1540#[sqlfunc(
1541    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
1542    is_infix_op = true,
1543    sqlname = "-",
1544    propagates_nulls = true,
1545    introduces_nulls = false
1546)]
1547fn range_difference<'a>(
1548    l: Range<Datum<'a>>,
1549    r: Range<Datum<'a>>,
1550) -> Result<Range<Datum<'a>>, EvalError> {
1551    Ok(l.difference(&r)?)
1552}
1553
// SQL `=` on datums. `negate` records the inverse operator (`!=`). The
// `ExcludeNull` parameters presumably keep NULLs out of the body; see the
// comment below.
#[sqlfunc(is_infix_op = true, sqlname = "=", negate = "Some(NotEq.into())")]
fn eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    // SQL equality demands that if either input is null, then the result should be null. However,
    // we don't need to handle this case here; it is handled when `BinaryFunc::eval` checks
    // `propagates_nulls`.
    a == b
}
1561
// SQL `!=` on datums. `negate` records the inverse operator (`=`). NULL
// handling matches `eq`: the parameters are `ExcludeNull`.
#[sqlfunc(is_infix_op = true, sqlname = "!=", negate = "Some(Eq.into())")]
fn not_eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    a != b
}
1566
// SQL `<` using the ordering on `Datum`. It is declared monotone in both
// arguments, and `negate` records the inverse operator (`>=`).
#[sqlfunc(
    is_monotone = "(true, true)",
    is_infix_op = true,
    sqlname = "<",
    negate = "Some(Gte.into())"
)]
fn lt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    a < b
}
1576
// SQL `<=` using the ordering on `Datum`. It is declared monotone in both
// arguments, and `negate` records the inverse operator (`>`).
#[sqlfunc(
    is_monotone = "(true, true)",
    is_infix_op = true,
    sqlname = "<=",
    negate = "Some(Gt.into())"
)]
fn lte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    a <= b
}
1586
// SQL `>` using the ordering on `Datum`. It is declared monotone in both
// arguments, and `negate` records the inverse operator (`<=`).
#[sqlfunc(
    is_monotone = "(true, true)",
    is_infix_op = true,
    sqlname = ">",
    negate = "Some(Lte.into())"
)]
fn gt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    a > b
}
1596
// SQL `>=` using the ordering on `Datum`. It is declared monotone in both
// arguments, and `negate` records the inverse operator (`<`).
#[sqlfunc(
    is_monotone = "(true, true)",
    is_infix_op = true,
    sqlname = ">=",
    negate = "Some(Lt.into())"
)]
fn gte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
    a >= b
}
1606
1607#[sqlfunc(sqlname = "tocharts", propagates_nulls = true)]
1608fn to_char_timestamp_format(ts: CheckedTimestamp<chrono::NaiveDateTime>, format: &str) -> String {
1609    let fmt = DateTimeFormat::compile(format);
1610    fmt.render(&*ts)
1611}
1612
1613#[sqlfunc(sqlname = "tochartstz", propagates_nulls = true)]
1614fn to_char_timestamp_tz_format(
1615    ts: CheckedTimestamp<chrono::DateTime<Utc>>,
1616    format: &str,
1617) -> String {
1618    let fmt = DateTimeFormat::compile(format);
1619    fmt.render(&*ts)
1620}
1621
1622#[sqlfunc(sqlname = "->", is_infix_op = true)]
1623fn jsonb_get_int64<'a>(a: JsonbRef<'a>, i: i64) -> Option<JsonbRef<'a>> {
1624    match a.into_datum() {
1625        Datum::List(list) => {
1626            let i = if i >= 0 {
1627                usize::cast_from(i.unsigned_abs())
1628            } else {
1629                // index backwards from the end
1630                let i = usize::cast_from(i.unsigned_abs());
1631                (list.iter().count()).wrapping_sub(i)
1632            };
1633            let v = list.iter().nth(i)?;
1634            // `v` should be valid jsonb because it came from a jsonb list, but we don't
1635            // panic on mismatch to avoid bringing down the whole system on corrupt data.
1636            // Instead, we'll return None.
1637            JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1638        }
1639        Datum::Map(_) => None,
1640        _ => {
1641            // I have no idea why postgres does this, but we're stuck with it
1642            (i == 0 || i == -1).then_some(a)
1643        }
1644    }
1645}
1646
1647#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1648fn jsonb_get_int64_stringify<'a>(
1649    a: JsonbRef<'a>,
1650    i: i64,
1651    temp_storage: &'a RowArena,
1652) -> Option<&'a str> {
1653    let json = jsonb_get_int64(a, i)?;
1654    jsonb_stringify(json.into_datum(), temp_storage)
1655}
1656
1657#[sqlfunc(sqlname = "->", is_infix_op = true)]
1658fn jsonb_get_string<'a>(a: JsonbRef<'a>, k: &str) -> Option<JsonbRef<'a>> {
1659    let dict = DatumMap::try_from_result(Ok::<_, ()>(a.into_datum())).ok()?;
1660    let v = dict.iter().find(|(k2, _v)| k == *k2).map(|(_k, v)| v)?;
1661    JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1662}
1663
1664#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1665fn jsonb_get_string_stringify<'a>(
1666    a: JsonbRef<'a>,
1667    k: &str,
1668    temp_storage: &'a RowArena,
1669) -> Option<&'a str> {
1670    let v = jsonb_get_string(a, k)?;
1671    jsonb_stringify(v.into_datum(), temp_storage)
1672}
1673
// `jsonb #> text[]`: walks `json` along the path `b`, one element per step.
//
// - Object step: the path element is looked up as a key.
// - Array step: the path element is parsed as an integer index; negative
//   values count back from the end.
// Any miss yields `None`, including:
//   - a NULL path element;
//   - a non-integer array index;
//   - an index or key that isn't present;
//   - descending into a scalar.
#[sqlfunc(sqlname = "#>", is_infix_op = true)]
fn jsonb_get_path<'a>(mut json: JsonbRef<'a>, b: Array<'a>) -> Option<JsonbRef<'a>> {
    let path = b.elements();
    for key in path.iter() {
        let key = match key {
            Datum::String(s) => s,
            Datum::Null => return None,
            _ => unreachable!("keys in jsonb_get_path known to be strings"),
        };
        let v = match json.into_datum() {
            Datum::Map(map) => map.iter().find(|(k, _)| key == *k).map(|(_k, v)| v),
            Datum::List(list) => {
                let i = strconv::parse_int64(key).ok()?;
                let i = if i >= 0 {
                    usize::cast_from(i.unsigned_abs())
                } else {
                    // index backwards from the end
                    let i = usize::cast_from(i.unsigned_abs());
                    // Wraps to a huge position when |i| exceeds the length,
                    // which `nth` below then rejects.
                    (list.iter().count()).wrapping_sub(i)
                };
                list.iter().nth(i)
            }
            _ => return None,
        }?;
        json = JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()?;
    }
    Some(json)
}
1702
1703#[sqlfunc(sqlname = "#>>", is_infix_op = true)]
1704fn jsonb_get_path_stringify<'a>(
1705    a: JsonbRef<'a>,
1706    b: Array<'a>,
1707    temp_storage: &'a RowArena,
1708) -> Option<&'a str> {
1709    let json = jsonb_get_path(a, b)?;
1710    jsonb_stringify(json.into_datum(), temp_storage)
1711}
1712
1713#[sqlfunc(is_infix_op = true, sqlname = "?")]
1714fn jsonb_contains_string<'a>(a: JsonbRef<'a>, k: &str) -> bool {
1715    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
1716    // When the left operand is SQL NULL (NULL::jsonb), JsonbRef::try_from_result rejects it,
1717    // so the binary evaluator never calls this function and returns NULL (see binary.rs).
1718    // So, this function only runs for non-null jsonb; a.into_datum() never sees Datum::Null.
1719    match a.into_datum() {
1720        Datum::List(list) => list.iter().any(|k2| Datum::from(k) == k2),
1721        Datum::Map(dict) => dict.iter().any(|(k2, _v)| k == k2),
1722        Datum::String(string) => string == k,
1723        _ => false,
1724    }
1725}
1726
1727#[sqlfunc(is_infix_op = true, sqlname = "?", propagates_nulls = true)]
1728// Map keys are always text.
1729fn map_contains_key<'a>(map: DatumMap<'a>, k: &str) -> bool {
1730    map.iter().any(|(k2, _v)| k == k2)
1731}
1732
1733#[sqlfunc(is_infix_op = true, sqlname = "?&")]
1734fn map_contains_all_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1735    keys.elements()
1736        .iter()
1737        .all(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1738}
1739
1740#[sqlfunc(is_infix_op = true, sqlname = "?|", propagates_nulls = true)]
1741fn map_contains_any_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1742    keys.elements()
1743        .iter()
1744        .any(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1745}
1746
1747#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1748fn map_contains_map<'a>(map_a: DatumMap<'a>, b: DatumMap<'a>) -> bool {
1749    b.iter().all(|(b_key, b_val)| {
1750        map_a
1751            .iter()
1752            .any(|(a_key, a_val)| (a_key == b_key) && (a_val == b_val))
1753    })
1754}
1755
1756#[sqlfunc(is_infix_op = true, sqlname = "->", propagates_nulls = true)]
1757fn map_get_value<'a, T: FromDatum<'a>>(a: DatumMap<'a, T>, target_key: &str) -> Option<T> {
1758    a.typed_iter()
1759        .find(|(key, _v)| target_key == *key)
1760        .map(|(_k, v)| v)
1761}
1762
1763#[sqlfunc(is_infix_op = true, sqlname = "@>")]
1764fn list_contains_list<'a>(a: ExcludeNull<DatumList<'a>>, b: ExcludeNull<DatumList<'a>>) -> bool {
1765    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
1766    if b.iter().contains(&Datum::Null) {
1767        false
1768    } else {
1769        b.iter()
1770            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
1771    }
1772}
1773
// `list <@ list`: whether `a` is contained in `b`. Implemented by swapping
// the arguments of `list_contains_list`.
#[sqlfunc(is_infix_op = true, sqlname = "<@")]
fn list_contains_list_rev<'a>(
    a: ExcludeNull<DatumList<'a>>,
    b: ExcludeNull<DatumList<'a>>,
) -> bool {
    list_contains_list(b, a)
}
1781
// TODO(jamii) nested loops are possibly not the fastest way to do this
// `jsonb @> jsonb`: structural containment of `b` in `a`.
//
// - Scalars: contained when equal.
// - Arrays: every element of `b` must be contained in some element of `a`.
// - Objects: every key of `b` must be present in `a` with a contained value.
// - Top-level array `a` with non-array `b`: contained when some element of
//   `a` contains `b`. This special case is not applied below the top level.
// - Any other combination of kinds: not contained.
#[sqlfunc(is_infix_op = true, sqlname = "@>")]
fn jsonb_contains_jsonb<'a>(a: JsonbRef<'a>, b: JsonbRef<'a>) -> bool {
    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
    fn contains(a: Datum, b: Datum, at_top_level: bool) -> bool {
        match (a, b) {
            (Datum::JsonNull, Datum::JsonNull) => true,
            (Datum::False, Datum::False) => true,
            (Datum::True, Datum::True) => true,
            (Datum::Numeric(a), Datum::Numeric(b)) => a == b,
            (Datum::String(a), Datum::String(b)) => a == b,
            (Datum::List(a), Datum::List(b)) => b
                .iter()
                .all(|b_elem| a.iter().any(|a_elem| contains(a_elem, b_elem, false))),
            (Datum::Map(a), Datum::Map(b)) => b.iter().all(|(b_key, b_val)| {
                a.iter()
                    .any(|(a_key, a_val)| (a_key == b_key) && contains(a_val, b_val, false))
            }),

            // fun special case
            // (must stay after the List/List arm so it only sees non-array `b`)
            (Datum::List(a), b) => {
                at_top_level && a.iter().any(|a_elem| contains(a_elem, b, false))
            }

            _ => false,
        }
    }
    contains(a.into_datum(), b.into_datum(), true)
}
1811
// `jsonb || jsonb`.
//
// - Object || object: merges the two, with `b`'s value winning on key
//   collisions.
// - Array || array: concatenates them.
// - Array || non-array: appends the value.
// - Non-array || array: prepends the value.
// - Any other combination yields `None`.
// The result is allocated in `temp_storage`.
#[sqlfunc(is_infix_op = true, sqlname = "||")]
fn jsonb_concat<'a>(
    a: JsonbRef<'a>,
    b: JsonbRef<'a>,
    temp_storage: &'a RowArena,
) -> Option<JsonbRef<'a>> {
    let res = match (a.into_datum(), b.into_datum()) {
        (Datum::Map(dict_a), Datum::Map(dict_b)) => {
            let mut pairs = dict_b.iter().chain(dict_a.iter()).collect::<Vec<_>>();
            // stable sort, so if keys collide dedup prefers dict_b
            pairs.sort_by(|(k1, _v1), (k2, _v2)| k1.cmp(k2));
            pairs.dedup_by(|(k1, _v1), (k2, _v2)| k1 == k2);
            temp_storage.make_datum(|packer| packer.push_dict(pairs))
        }
        (Datum::List(list_a), Datum::List(list_b)) => {
            let elems = list_a.iter().chain(list_b.iter());
            temp_storage.make_datum(|packer| packer.push_list(elems))
        }
        (Datum::List(list_a), b) => {
            let elems = list_a.iter().chain(Some(b));
            temp_storage.make_datum(|packer| packer.push_list(elems))
        }
        (a, Datum::List(list_b)) => {
            let elems = Some(a).into_iter().chain(list_b.iter());
            temp_storage.make_datum(|packer| packer.push_list(elems))
        }
        _ => return None,
    };
    Some(JsonbRef::from_datum(res))
}
1842
1843#[sqlfunc(
1844    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1845    is_infix_op = true,
1846    sqlname = "-",
1847    propagates_nulls = true,
1848    introduces_nulls = true
1849)]
1850fn jsonb_delete_int64<'a>(a: Datum<'a>, i: i64, temp_storage: &'a RowArena) -> Datum<'a> {
1851    match a {
1852        Datum::List(list) => {
1853            let i = if i >= 0 {
1854                usize::cast_from(i.unsigned_abs())
1855            } else {
1856                // index backwards from the end
1857                let i = usize::cast_from(i.unsigned_abs());
1858                (list.iter().count()).wrapping_sub(i)
1859            };
1860            let elems = list
1861                .iter()
1862                .enumerate()
1863                .filter(|(i2, _e)| i != *i2)
1864                .map(|(_, e)| e);
1865            temp_storage.make_datum(|packer| packer.push_list(elems))
1866        }
1867        _ => Datum::Null,
1868    }
1869}
1870
1871#[sqlfunc(
1872    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1873    is_infix_op = true,
1874    sqlname = "-",
1875    propagates_nulls = true,
1876    introduces_nulls = true
1877)]
1878fn jsonb_delete_string<'a>(a: Datum<'a>, k: &str, temp_storage: &'a RowArena) -> Datum<'a> {
1879    match a {
1880        Datum::List(list) => {
1881            let elems = list.iter().filter(|e| Datum::from(k) != *e);
1882            temp_storage.make_datum(|packer| packer.push_list(elems))
1883        }
1884        Datum::Map(dict) => {
1885            let pairs = dict.iter().filter(|(k2, _v)| k != *k2);
1886            temp_storage.make_datum(|packer| packer.push_dict(pairs))
1887        }
1888        _ => Datum::Null,
1889    }
1890}
1891
1892#[sqlfunc(
1893    sqlname = "extractiv",
1894    propagates_nulls = true,
1895    introduces_nulls = false
1896)]
1897fn date_part_interval_numeric(units: &str, b: Interval) -> Result<Numeric, EvalError> {
1898    match units.parse() {
1899        Ok(units) => Ok(date_part_interval_inner::<Numeric>(units, b)?),
1900        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1901    }
1902}
1903
1904#[sqlfunc(
1905    sqlname = "date_partiv",
1906    propagates_nulls = true,
1907    introduces_nulls = false
1908)]
1909fn date_part_interval_f64(units: &str, b: Interval) -> Result<f64, EvalError> {
1910    match units.parse() {
1911        Ok(units) => Ok(date_part_interval_inner::<f64>(units, b)?),
1912        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1913    }
1914}
1915
1916#[sqlfunc(
1917    sqlname = "extractt",
1918    propagates_nulls = true,
1919    introduces_nulls = false
1920)]
1921fn date_part_time_numeric(units: &str, b: chrono::NaiveTime) -> Result<Numeric, EvalError> {
1922    match units.parse() {
1923        Ok(units) => Ok(date_part_time_inner::<Numeric>(units, b)?),
1924        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1925    }
1926}
1927
1928#[sqlfunc(
1929    sqlname = "date_partt",
1930    propagates_nulls = true,
1931    introduces_nulls = false
1932)]
1933fn date_part_time_f64(units: &str, b: chrono::NaiveTime) -> Result<f64, EvalError> {
1934    match units.parse() {
1935        Ok(units) => Ok(date_part_time_inner::<f64>(units, b)?),
1936        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1937    }
1938}
1939
1940#[sqlfunc(sqlname = "extractts", propagates_nulls = true)]
1941fn date_part_timestamp_timestamp_numeric(
1942    units: &str,
1943    ts: CheckedTimestamp<NaiveDateTime>,
1944) -> Result<Numeric, EvalError> {
1945    match units.parse() {
1946        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1947        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1948    }
1949}
1950
1951#[sqlfunc(sqlname = "extracttstz", propagates_nulls = true)]
1952fn date_part_timestamp_timestamp_tz_numeric(
1953    units: &str,
1954    ts: CheckedTimestamp<DateTime<Utc>>,
1955) -> Result<Numeric, EvalError> {
1956    match units.parse() {
1957        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1958        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1959    }
1960}
1961
1962#[sqlfunc(sqlname = "date_partts", propagates_nulls = true)]
1963fn date_part_timestamp_timestamp_f64(
1964    units: &str,
1965    ts: CheckedTimestamp<NaiveDateTime>,
1966) -> Result<f64, EvalError> {
1967    match units.parse() {
1968        Ok(units) => date_part_timestamp_inner(units, &*ts),
1969        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1970    }
1971}
1972
1973#[sqlfunc(sqlname = "date_parttstz", propagates_nulls = true)]
1974fn date_part_timestamp_timestamp_tz_f64(
1975    units: &str,
1976    ts: CheckedTimestamp<DateTime<Utc>>,
1977) -> Result<f64, EvalError> {
1978    match units.parse() {
1979        Ok(units) => date_part_timestamp_inner(units, &*ts),
1980        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1981    }
1982}
1983
1984#[sqlfunc(sqlname = "extractd", propagates_nulls = true)]
1985fn extract_date_units(units: &str, b: Date) -> Result<Numeric, EvalError> {
1986    match units.parse() {
1987        Ok(units) => Ok(extract_date_inner(units, b.into())?),
1988        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1989    }
1990}
1991
1992pub fn date_bin<T>(
1993    stride: Interval,
1994    source: CheckedTimestamp<T>,
1995    origin: CheckedTimestamp<T>,
1996) -> Result<CheckedTimestamp<T>, EvalError>
1997where
1998    T: TimestampLike,
1999{
2000    if stride.months != 0 {
2001        return Err(EvalError::DateBinOutOfRange(
2002            "timestamps cannot be binned into intervals containing months or years".into(),
2003        ));
2004    }
2005
2006    let stride_ns = match stride.duration_as_chrono().num_nanoseconds() {
2007        Some(ns) if ns <= 0 => Err(EvalError::DateBinOutOfRange(
2008            "stride must be greater than zero".into(),
2009        )),
2010        Some(ns) => Ok(ns),
2011        None => Err(EvalError::DateBinOutOfRange(
2012            format!("stride cannot exceed {}/{} nanoseconds", i64::MAX, i64::MIN,).into(),
2013        )),
2014    }?;
2015
2016    // Make sure the returned timestamp is at the start of the bin, even if the
2017    // origin is in the future. We do this here because `T` is not `Copy` and
2018    // gets moved by its subtraction operation.
2019    let sub_stride = origin > source;
2020
2021    let tm_diff = (source - origin.clone()).num_nanoseconds().ok_or_else(|| {
2022        EvalError::DateBinOutOfRange(
2023            "source and origin must not differ more than 2^63 nanoseconds".into(),
2024        )
2025    })?;
2026
2027    let mut tm_delta = tm_diff - tm_diff % stride_ns;
2028
2029    if sub_stride {
2030        tm_delta -= stride_ns;
2031    }
2032
2033    let res = origin
2034        .checked_add_signed(Duration::nanoseconds(tm_delta))
2035        .ok_or(EvalError::TimestampOutOfRange)?;
2036    Ok(CheckedTimestamp::from_timestamplike(res)?)
2037}
2038
2039#[sqlfunc(is_monotone = "(true, true)", sqlname = "bin_unix_epoch_timestamp")]
2040fn date_bin_timestamp(
2041    stride: Interval,
2042    source: CheckedTimestamp<NaiveDateTime>,
2043) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2044    let origin =
2045        CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap().naive_utc())
2046            .expect("must fit");
2047    date_bin(stride, source, origin)
2048}
2049
2050#[sqlfunc(is_monotone = "(true, true)", sqlname = "bin_unix_epoch_timestamptz")]
2051fn date_bin_timestamp_tz(
2052    stride: Interval,
2053    source: CheckedTimestamp<DateTime<Utc>>,
2054) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2055    let origin = CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap())
2056        .expect("must fit");
2057    date_bin(stride, source, origin)
2058}
2059
2060#[sqlfunc(sqlname = "date_truncts", propagates_nulls = true)]
2061fn date_trunc_units_timestamp(
2062    units: &str,
2063    ts: CheckedTimestamp<NaiveDateTime>,
2064) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2065    match units.parse() {
2066        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2067        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2068    }
2069}
2070
2071#[sqlfunc(sqlname = "date_trunctstz", propagates_nulls = true)]
2072fn date_trunc_units_timestamp_tz(
2073    units: &str,
2074    ts: CheckedTimestamp<DateTime<Utc>>,
2075) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2076    match units.parse() {
2077        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2078        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2079    }
2080}
2081
2082#[sqlfunc(sqlname = "date_trunciv", propagates_nulls = true)]
2083fn date_trunc_interval(units: &str, mut interval: Interval) -> Result<Interval, EvalError> {
2084    let dtf = units
2085        .parse()
2086        .map_err(|_| EvalError::UnknownUnits(units.into()))?;
2087
2088    interval
2089        .truncate_low_fields(dtf, Some(0), RoundBehavior::Truncate)
2090        .expect(
2091            "truncate_low_fields should not fail with max_precision 0 and RoundBehavior::Truncate",
2092        );
2093    Ok(interval)
2094}
2095
2096/// Parses a named timezone like `EST` or `America/New_York`, or a fixed-offset timezone like `-05:00`.
2097///
2098/// The interpretation of fixed offsets depend on whether the POSIX or ISO 8601 standard is being
2099/// used.
2100pub(crate) fn parse_timezone(tz: &str, spec: TimezoneSpec) -> Result<Timezone, EvalError> {
2101    Timezone::parse(tz, spec).map_err(|_| EvalError::InvalidTimezone(tz.into()))
2102}
2103
2104/// Converts the time datum `b`, which is assumed to be in UTC, to the timezone that the interval datum `a` is assumed
2105/// to represent. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2106/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2107#[sqlfunc(sqlname = "timezoneit")]
2108fn timezone_interval_time_binary(
2109    interval: Interval,
2110    time: chrono::NaiveTime,
2111) -> Result<chrono::NaiveTime, EvalError> {
2112    if interval.months != 0 {
2113        Err(EvalError::InvalidTimezoneInterval)
2114    } else {
2115        Ok(time.overflowing_add_signed(interval.duration_as_chrono()).0)
2116    }
2117}
2118
2119/// Converts the timestamp datum `b`, which is assumed to be in the time of the timezone datum `a` to a timestamptz
2120/// in UTC. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2121/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2122#[sqlfunc(sqlname = "timezoneits")]
2123fn timezone_interval_timestamp_binary(
2124    interval: Interval,
2125    ts: CheckedTimestamp<NaiveDateTime>,
2126) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2127    if interval.months != 0 {
2128        Err(EvalError::InvalidTimezoneInterval)
2129    } else {
2130        match ts.checked_sub_signed(interval.duration_as_chrono()) {
2131            Some(sub) => Ok(DateTime::from_naive_utc_and_offset(sub, Utc).try_into()?),
2132            None => Err(EvalError::TimestampOutOfRange),
2133        }
2134    }
2135}
2136
2137/// Converts the UTC timestamptz datum `b`, to the local timestamp of the timezone datum `a`.
2138/// The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2139/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2140#[sqlfunc(sqlname = "timezoneitstz")]
2141fn timezone_interval_timestamp_tz_binary(
2142    interval: Interval,
2143    tstz: CheckedTimestamp<DateTime<Utc>>,
2144) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2145    if interval.months != 0 {
2146        return Err(EvalError::InvalidTimezoneInterval);
2147    }
2148    match tstz
2149        .naive_utc()
2150        .checked_add_signed(interval.duration_as_chrono())
2151    {
2152        Some(dt) => Ok(dt.try_into()?),
2153        None => Err(EvalError::TimestampOutOfRange),
2154    }
2155}
2156
2157#[sqlfunc(
2158    output_type_expr = r#"SqlScalarType::Record {
2159                fields: [
2160                    ("abbrev".into(), SqlScalarType::String.nullable(false)),
2161                    ("base_utc_offset".into(), SqlScalarType::Interval.nullable(false)),
2162                    ("dst_offset".into(), SqlScalarType::Interval.nullable(false)),
2163                ].into(),
2164                custom_id: None,
2165            }.nullable(true)"#,
2166    propagates_nulls = true,
2167    introduces_nulls = false
2168)]
2169fn timezone_offset<'a>(
2170    tz_str: &str,
2171    b: CheckedTimestamp<chrono::DateTime<Utc>>,
2172    temp_storage: &'a RowArena,
2173) -> Result<Datum<'a>, EvalError> {
2174    let tz = match Tz::from_str_insensitive(tz_str) {
2175        Ok(tz) => tz,
2176        Err(_) => return Err(EvalError::InvalidIanaTimezoneId(tz_str.into())),
2177    };
2178    let offset = tz.offset_from_utc_datetime(&b.naive_utc());
2179    Ok(temp_storage.make_datum(|packer| {
2180        packer.push_list_with(|packer| {
2181            packer.push(Datum::from(offset.abbreviation()));
2182            packer.push(Datum::from(offset.base_utc_offset()));
2183            packer.push(Datum::from(offset.dst_offset()));
2184        });
2185    }))
2186}
2187
2188/// Determines if an mz_aclitem contains one of the specified privileges. This will return true if
2189/// any of the listed privileges are contained in the mz_aclitem.
2190#[sqlfunc(
2191    sqlname = "mz_aclitem_contains_privilege",
2192    output_type = "bool",
2193    propagates_nulls = true
2194)]
2195fn mz_acl_item_contains_privilege(
2196    mz_acl_item: MzAclItem,
2197    privileges: &str,
2198) -> Result<bool, EvalError> {
2199    let acl_mode = AclMode::parse_multiple_privileges(privileges)
2200        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
2201    let contains = !mz_acl_item.acl_mode.intersection(acl_mode).is_empty();
2202    Ok(contains)
2203}
2204
2205#[sqlfunc]
2206// transliterated from postgres/src/backend/utils/adt/misc.c
2207fn parse_ident<'a>(ident: &'a str, strict: bool) -> Result<ArrayRustType<Cow<'a, str>>, EvalError> {
2208    fn is_ident_start(c: char) -> bool {
2209        matches!(c, 'A'..='Z' | 'a'..='z' | '_' | '\u{80}'..=char::MAX)
2210    }
2211
2212    fn is_ident_cont(c: char) -> bool {
2213        matches!(c, '0'..='9' | '$') || is_ident_start(c)
2214    }
2215
2216    let mut elems = vec![];
2217    let buf = &mut LexBuf::new(ident);
2218
2219    let mut after_dot = false;
2220
2221    buf.take_while(|ch| ch.is_ascii_whitespace());
2222
2223    loop {
2224        let mut missing_ident = true;
2225
2226        let c = buf.next();
2227
2228        if c == Some('"') {
2229            let s = buf.take_while(|ch| !matches!(ch, '"'));
2230
2231            if buf.next() != Some('"') {
2232                return Err(EvalError::InvalidIdentifier {
2233                    ident: ident.into(),
2234                    detail: Some("String has unclosed double quotes.".into()),
2235                });
2236            }
2237            elems.push(Cow::Borrowed(s));
2238            missing_ident = false;
2239        } else if c.map(is_ident_start).unwrap_or(false) {
2240            buf.prev();
2241            let s = buf.take_while(is_ident_cont);
2242            elems.push(Cow::Owned(s.to_ascii_lowercase()));
2243            missing_ident = false;
2244        }
2245
2246        if missing_ident {
2247            if c == Some('.') {
2248                return Err(EvalError::InvalidIdentifier {
2249                    ident: ident.into(),
2250                    detail: Some("No valid identifier before \".\".".into()),
2251                });
2252            } else if after_dot {
2253                return Err(EvalError::InvalidIdentifier {
2254                    ident: ident.into(),
2255                    detail: Some("No valid identifier after \".\".".into()),
2256                });
2257            } else {
2258                return Err(EvalError::InvalidIdentifier {
2259                    ident: ident.into(),
2260                    detail: None,
2261                });
2262            }
2263        }
2264
2265        buf.take_while(|ch| ch.is_ascii_whitespace());
2266
2267        match buf.next() {
2268            Some('.') => {
2269                after_dot = true;
2270
2271                buf.take_while(|ch| ch.is_ascii_whitespace());
2272            }
2273            Some(_) if strict => {
2274                return Err(EvalError::InvalidIdentifier {
2275                    ident: ident.into(),
2276                    detail: None,
2277                });
2278            }
2279            _ => break,
2280        }
2281    }
2282
2283    Ok(elems.into())
2284}
2285
2286fn regexp_split_to_array_re<'a>(
2287    text: &str,
2288    regexp: &Regex,
2289    temp_storage: &'a RowArena,
2290) -> Result<Datum<'a>, EvalError> {
2291    let found = mz_regexp::regexp_split_to_array(text, regexp);
2292    let mut row = Row::default();
2293    let mut packer = row.packer();
2294    packer.try_push_array(
2295        &[ArrayDimension {
2296            lower_bound: 1,
2297            length: found.len(),
2298        }],
2299        found.into_iter().map(Datum::String),
2300    )?;
2301    Ok(temp_storage.push_unary_row(row))
2302}
2303
2304#[sqlfunc(propagates_nulls = true)]
2305fn pretty_sql<'a>(sql: &str, width: i32, temp_storage: &'a RowArena) -> Result<&'a str, EvalError> {
2306    let width =
2307        usize::try_from(width).map_err(|_| EvalError::PrettyError("invalid width".into()))?;
2308    let pretty = pretty_str(
2309        sql,
2310        PrettyConfig {
2311            width,
2312            format_mode: FormatMode::Simple,
2313        },
2314    )
2315    .map_err(|e| EvalError::PrettyError(e.to_string().into()))?;
2316    let pretty = temp_storage.push_string(pretty);
2317    Ok(pretty)
2318}
2319
2320#[sqlfunc]
2321fn redact_sql(sql: &str) -> Result<String, EvalError> {
2322    let stmts = mz_sql_parser::parser::parse_statements(sql)
2323        .map_err(|e| EvalError::RedactError(e.to_string().into()))?;
2324    match stmts.len() {
2325        1 => Ok(stmts[0].ast.to_ast_string_redacted()),
2326        n => Err(EvalError::RedactError(
2327            format!("expected a single statement, found {n}").into(),
2328        )),
2329    }
2330}
2331
2332#[sqlfunc(propagates_nulls = true)]
2333fn starts_with(a: &str, b: &str) -> bool {
2334    a.starts_with(b)
2335}
2336
2337#[sqlfunc(
2338    sqlname = "||",
2339    is_infix_op = true,
2340    propagates_nulls = true,
2341    // Text concatenation is monotonic in its second argument, because if I change the
2342    // second argument but don't change the first argument, then we won't find a difference
2343    // in that part of the concatenation result that came from the first argument, so we'll
2344    // find the difference that comes from changing the second argument.
2345    // (It's not monotonic in its first argument, because e.g.,
2346    // 'A' < 'AA' but 'AZ' > 'AAZ'.)
2347    is_monotone = (false, true),
2348)]
2349fn text_concat_binary(a: &str, b: &str) -> Result<String, EvalError> {
2350    if a.len() + b.len() > MAX_STRING_FUNC_RESULT_BYTES {
2351        return Err(EvalError::LengthTooLarge);
2352    }
2353    let mut buf = String::with_capacity(a.len() + b.len());
2354    buf.push_str(a);
2355    buf.push_str(b);
2356    Ok(buf)
2357}
2358
2359#[sqlfunc(propagates_nulls = true, introduces_nulls = false)]
2360fn like_escape<'a>(
2361    pattern: &str,
2362    b: &str,
2363    temp_storage: &'a RowArena,
2364) -> Result<&'a str, EvalError> {
2365    let escape = like_pattern::EscapeBehavior::from_str(b)?;
2366    let normalized = like_pattern::normalize_pattern(pattern, escape)?;
2367    Ok(temp_storage.push_string(normalized))
2368}
2369
2370#[sqlfunc(is_infix_op = true, sqlname = "like")]
2371fn is_like_match_case_sensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2372    like_pattern::compile(pattern, false).map(|needle| needle.is_match(haystack))
2373}
2374
2375#[sqlfunc(is_infix_op = true, sqlname = "ilike")]
2376fn is_like_match_case_insensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2377    like_pattern::compile(pattern, true).map(|needle| needle.is_match(haystack))
2378}
2379
2380#[sqlfunc(is_infix_op = true, sqlname = "~")]
2381fn is_regexp_match_case_sensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2382    let regex = build_regex(needle, "")?;
2383    Ok(regex.is_match(haystack))
2384}
2385
2386#[sqlfunc(is_infix_op = true, sqlname = "~*")]
2387fn is_regexp_match_case_insensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2388    let regex = build_regex(needle, "i")?;
2389    Ok(regex.is_match(haystack))
2390}
2391
2392fn regexp_match_static<'a>(
2393    haystack: Datum<'a>,
2394    temp_storage: &'a RowArena,
2395    needle: &regex::Regex,
2396) -> Result<Datum<'a>, EvalError> {
2397    let mut row = Row::default();
2398    let mut packer = row.packer();
2399    if needle.captures_len() > 1 {
2400        // The regex contains capture groups, so return an array containing the
2401        // matched text in each capture group, unless the entire match fails.
2402        // Individual capture groups may also be null if that group did not
2403        // participate in the match.
2404        match needle.captures(haystack.unwrap_str()) {
2405            None => packer.push(Datum::Null),
2406            Some(captures) => packer.try_push_array(
2407                &[ArrayDimension {
2408                    lower_bound: 1,
2409                    length: captures.len() - 1,
2410                }],
2411                // Skip the 0th capture group, which is the whole match.
2412                captures.iter().skip(1).map(|mtch| match mtch {
2413                    None => Datum::Null,
2414                    Some(mtch) => Datum::String(mtch.as_str()),
2415                }),
2416            )?,
2417        }
2418    } else {
2419        // The regex contains no capture groups, so return a one-element array
2420        // containing the match, or null if there is no match.
2421        match needle.find(haystack.unwrap_str()) {
2422            None => packer.push(Datum::Null),
2423            Some(mtch) => packer.try_push_array(
2424                &[ArrayDimension {
2425                    lower_bound: 1,
2426                    length: 1,
2427                }],
2428                iter::once(Datum::String(mtch.as_str())),
2429            )?,
2430        };
2431    };
2432    Ok(temp_storage.push_unary_row(row))
2433}
2434
/// Splits the global-replace flag out of `flags` for use with
/// `Regex::replacen`.
///
/// Returns `(limit, remaining_flags)`. If `flags` contains any `'g'`, the limit
/// is 0 (replace every match) and all `'g'`s are removed. Otherwise the limit
/// is 1 (replace the first match) and `flags` is returned borrowed, so the
/// common case does not allocate.
pub(crate) fn regexp_replace_parse_flags(flags: &str) -> (usize, Cow<'_, str>) {
    if !flags.contains('g') {
        return (1, Cow::Borrowed(flags));
    }
    let remaining: String = flags.chars().filter(|&c| c != 'g').collect();
    (0, Cow::Owned(remaining))
}
2448
2449pub fn build_regex(needle: &str, flags: &str) -> Result<Regex, EvalError> {
2450    let mut case_insensitive = false;
2451    // Note: Postgres accepts it when both flags are present, taking the last one. We do the same.
2452    for f in flags.chars() {
2453        match f {
2454            'i' => {
2455                case_insensitive = true;
2456            }
2457            'c' => {
2458                case_insensitive = false;
2459            }
2460            _ => return Err(EvalError::InvalidRegexFlag(f)),
2461        }
2462    }
2463    Ok(Regex::new(needle, case_insensitive)?)
2464}
2465
2466#[sqlfunc(sqlname = "repeat")]
2467fn repeat_string(string: &str, count: i32) -> Result<String, EvalError> {
2468    let len = usize::try_from(count).unwrap_or(0);
2469    if (len * string.len()) > MAX_STRING_FUNC_RESULT_BYTES {
2470        return Err(EvalError::LengthTooLarge);
2471    }
2472    Ok(string.repeat(len))
2473}
2474
2475/// Constructs a new zero or one dimensional array out of an arbitrary number of
2476/// scalars.
2477///
2478/// If `datums` is empty, constructs a zero-dimensional array. Otherwise,
2479/// constructs a one dimensional array whose lower bound is one and whose length
2480/// is equal to `datums.len()`.
2481fn array_create_scalar<'a>(
2482    datums: &[Datum<'a>],
2483    temp_storage: &'a RowArena,
2484) -> Result<Datum<'a>, EvalError> {
2485    let mut dims = &[ArrayDimension {
2486        lower_bound: 1,
2487        length: datums.len(),
2488    }][..];
2489    if datums.is_empty() {
2490        // Per PostgreSQL, empty arrays are represented with zero dimensions,
2491        // not one dimension of zero length. We write this condition a little
2492        // strangely to satisfy the borrow checker while avoiding an allocation.
2493        dims = &[];
2494    }
2495    let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(dims, datums))?;
2496    Ok(datum)
2497}
2498
/// Writes the text form of datum `d`, interpreted as a value of type `ty`,
/// into `buf`.
///
/// Scalar types delegate to the matching `strconv::format_*` function.
/// Container types (records, arrays, lists, maps, int2vectors, ranges) call
/// `stringify_datum` recursively for each element. Null elements of records,
/// arrays, lists, and maps, and absent range bounds, are written with
/// `write_null`.
///
/// Returns the `strconv::Nestable` reported by the formatter. This presumably
/// tells callers whether the text needs escaping when nested inside a
/// container (TODO confirm against `strconv`).
///
/// # Errors
///
/// Propagates errors from `format_pg_legacy_char` and from nested element
/// formatting.
///
/// # Panics
///
/// NOTE(review): assumes `d` actually holds a value of type `ty`. The
/// `unwrap_*` accessors are expected to panic otherwise; confirm in `mz_repr`.
fn stringify_datum<'a, B>(
    buf: &mut B,
    d: Datum<'a>,
    ty: &SqlScalarType,
) -> Result<strconv::Nestable, EvalError>
where
    B: FormatBuffer,
{
    use SqlScalarType::*;
    match &ty {
        AclItem => Ok(strconv::format_acl_item(buf, d.unwrap_acl_item())),
        Bool => Ok(strconv::format_bool(buf, d.unwrap_bool())),
        Int16 => Ok(strconv::format_int16(buf, d.unwrap_int16())),
        Int32 => Ok(strconv::format_int32(buf, d.unwrap_int32())),
        Int64 => Ok(strconv::format_int64(buf, d.unwrap_int64())),
        UInt16 => Ok(strconv::format_uint16(buf, d.unwrap_uint16())),
        // OID-like types are all stored as u32 and printed numerically.
        UInt32 | Oid | RegClass | RegProc | RegType => {
            Ok(strconv::format_uint32(buf, d.unwrap_uint32()))
        }
        UInt64 => Ok(strconv::format_uint64(buf, d.unwrap_uint64())),
        Float32 => Ok(strconv::format_float32(buf, d.unwrap_float32())),
        Float64 => Ok(strconv::format_float64(buf, d.unwrap_float64())),
        Numeric { .. } => Ok(strconv::format_numeric(buf, &d.unwrap_numeric())),
        Date => Ok(strconv::format_date(buf, d.unwrap_date())),
        Time => Ok(strconv::format_time(buf, d.unwrap_time())),
        Timestamp { .. } => Ok(strconv::format_timestamp(buf, &d.unwrap_timestamp())),
        TimestampTz { .. } => Ok(strconv::format_timestamptz(buf, &d.unwrap_timestamptz())),
        Interval => Ok(strconv::format_interval(buf, d.unwrap_interval())),
        Bytes => Ok(strconv::format_bytes(buf, d.unwrap_bytes())),
        String | VarChar { .. } | PgLegacyName => Ok(strconv::format_string(buf, d.unwrap_str())),
        // `char(n)` values are padded to their declared length before printing.
        Char { length } => Ok(strconv::format_string(
            buf,
            &mz_repr::adt::char::format_str_pad(d.unwrap_str(), *length),
        )),
        PgLegacyChar => {
            // The helper's success value is not used as the `Nestable`; the
            // conservative `MayNeedEscaping` is reported instead.
            format_pg_legacy_char(buf, d.unwrap_uint8())?;
            Ok(strconv::Nestable::MayNeedEscaping)
        }
        Jsonb => Ok(strconv::format_jsonb(buf, JsonbRef::from_datum(d))),
        Uuid => Ok(strconv::format_uuid(buf, d.unwrap_uuid())),
        Record { fields, .. } => {
            // Walk the declared field types in step with the record's datums.
            // The `unwrap` assumes the datum has no more fields than `ty`
            // declares.
            let mut fields = fields.iter();
            strconv::format_record(buf, d.unwrap_list(), |buf, d| {
                let (_name, ty) = fields.next().unwrap();
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, &ty.scalar_type)
                }
            })
        }
        Array(elem_type) => strconv::format_array(
            buf,
            &d.unwrap_array().dims().into_iter().collect::<Vec<_>>(),
            d.unwrap_array().elements(),
            |buf, d| {
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, elem_type)
                }
            },
        ),
        List { element_type, .. } => strconv::format_list(buf, d.unwrap_list(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, element_type)
            }
        }),
        Map { value_type, .. } => strconv::format_map(buf, &d.unwrap_map(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, value_type)
            }
        }),
        // Elements are formatted as int16. Unlike the containers above, no
        // null check is performed here.
        Int2Vector => strconv::format_legacy_vector(buf, d.unwrap_array().elements(), |buf, d| {
            stringify_datum(buf.nonnull_buffer(), d, &SqlScalarType::Int16)
        }),
        MzTimestamp { .. } => Ok(strconv::format_mz_timestamp(buf, d.unwrap_mz_timestamp())),
        // `None` here is an absent bound and is written as null.
        Range { element_type } => strconv::format_range(buf, &d.unwrap_range(), |buf, d| match d {
            Some(d) => stringify_datum(buf.nonnull_buffer(), *d, element_type),
            None => Ok::<_, EvalError>(buf.write_null()),
        }),
        MzAclItem => Ok(strconv::format_mz_acl_item(buf, d.unwrap_mz_acl_item())),
    }
}
2587
2588#[sqlfunc]
2589fn position(substring: &str, string: &str) -> Result<i32, EvalError> {
2590    let char_index = string.find(substring);
2591
2592    if let Some(char_index) = char_index {
2593        // find the index in char space
2594        let string_prefix = &string[0..char_index];
2595
2596        let num_prefix_chars = string_prefix.chars().count();
2597        let num_prefix_chars = i32::try_from(num_prefix_chars)
2598            .map_err(|_| EvalError::Int32OutOfRange(num_prefix_chars.to_string().into()))?;
2599
2600        Ok(num_prefix_chars + 1)
2601    } else {
2602        Ok(0)
2603    }
2604}
2605
2606#[sqlfunc]
2607fn strpos(string: &str, substring: &str) -> Result<i32, EvalError> {
2608    position(substring, string)
2609}
2610
2611#[sqlfunc(
2612    propagates_nulls = true,
2613    // `left` is unfortunately not monotonic (at least for negative second arguments),
2614    // because 'aa' < 'z', but `left(_, -1)` makes 'a' > ''.
2615    is_monotone = (false, false)
2616)]
2617fn left<'a>(string: &'a str, b: i32) -> Result<&'a str, EvalError> {
2618    let n = i64::from(b);
2619
2620    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2621
2622    let end_in_bytes = match n.cmp(&0) {
2623        Ordering::Equal => 0,
2624        Ordering::Greater => {
2625            let n = usize::try_from(n).map_err(|_| {
2626                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2627            })?;
2628            // nth from the back
2629            byte_indices.nth(n).unwrap_or(string.len())
2630        }
2631        Ordering::Less => {
2632            let n = usize::try_from(n.abs() - 1).map_err(|_| {
2633                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2634            })?;
2635            byte_indices.rev().nth(n).unwrap_or(0)
2636        }
2637    };
2638
2639    Ok(&string[..end_in_bytes])
2640}
2641
2642#[sqlfunc(propagates_nulls = true)]
2643fn right<'a>(string: &'a str, n: i32) -> Result<&'a str, EvalError> {
2644    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2645
2646    let start_in_bytes = if n == 0 {
2647        string.len()
2648    } else if n > 0 {
2649        let n = usize::try_from(n - 1).map_err(|_| {
2650            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2651        })?;
2652        // nth from the back
2653        byte_indices.rev().nth(n).unwrap_or(0)
2654    } else if n == i32::MIN {
2655        // this seems strange but Postgres behaves like this
2656        0
2657    } else {
2658        let n = n.abs();
2659        let n = usize::try_from(n).map_err(|_| {
2660            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2661        })?;
2662        byte_indices.nth(n).unwrap_or(string.len())
2663    };
2664
2665    Ok(&string[start_in_bytes..])
2666}
2667
2668#[sqlfunc(sqlname = "btrim", propagates_nulls = true)]
/// Strips every leading and trailing character of `a` that appears anywhere
/// in `trim_chars` (PostgreSQL `btrim`).
fn trim<'a>(a: &'a str, trim_chars: &str) -> &'a str {
    let in_trim_set = |ch: char| trim_chars.chars().any(|t| t == ch);
    a.trim_matches(in_trim_set)
}
2672
2673#[sqlfunc(sqlname = "ltrim", propagates_nulls = true)]
/// Strips every leading character of `a` that appears anywhere in
/// `trim_chars` (PostgreSQL `ltrim`).
fn trim_leading<'a>(a: &'a str, trim_chars: &str) -> &'a str {
    let in_trim_set = |ch: char| trim_chars.chars().any(|t| t == ch);
    a.trim_start_matches(in_trim_set)
}
2677
2678#[sqlfunc(sqlname = "rtrim", propagates_nulls = true)]
2679fn trim_trailing<'a>(a: &'a str, trim_chars: &str) -> &'a str {
2680    a.trim_end_matches(|c| trim_chars.contains(c))
2681}
2682
2683#[sqlfunc(
2684    sqlname = "array_length",
2685    propagates_nulls = true,
2686    introduces_nulls = true
2687)]
2688fn array_length<'a>(a: Array<'a>, b: i64) -> Result<Option<i32>, EvalError> {
2689    let i = match usize::try_from(b) {
2690        Ok(0) | Err(_) => return Ok(None),
2691        Ok(n) => n - 1,
2692    };
2693    Ok(match a.dims().into_iter().nth(i) {
2694        None => None,
2695        Some(dim) => Some(
2696            dim.length
2697                .try_into()
2698                .map_err(|_| EvalError::Int32OutOfRange(dim.length.to_string().into()))?,
2699        ),
2700    })
2701}
2702
2703#[sqlfunc(
2704    output_type = "Option<i32>",
2705    is_infix_op = true,
2706    sqlname = "array_lower",
2707    propagates_nulls = true,
2708    introduces_nulls = true
2709)]
2710// TODO(benesch): remove potentially dangerous usage of `as`.
2711#[allow(clippy::as_conversions)]
2712fn array_lower<'a>(a: Array<'a>, i: i64) -> Option<i32> {
2713    if i < 1 {
2714        return None;
2715    }
2716    match a.dims().into_iter().nth(i as usize - 1) {
2717        Some(_) => Some(1),
2718        None => None,
2719    }
2720}
2721
2722#[sqlfunc(
2723    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2724    sqlname = "array_remove",
2725    propagates_nulls = false,
2726    introduces_nulls = false
2727)]
2728fn array_remove<'a>(
2729    arr: Array<'a>,
2730    b: Datum<'a>,
2731    temp_storage: &'a RowArena,
2732) -> Result<Datum<'a>, EvalError> {
2733    // Zero-dimensional arrays are empty by definition
2734    if arr.dims().len() == 0 {
2735        return Ok(Datum::Array(arr));
2736    }
2737
2738    // array_remove only supports one-dimensional arrays
2739    if arr.dims().len() > 1 {
2740        return Err(EvalError::MultidimensionalArrayRemovalNotSupported);
2741    }
2742
2743    let elems: Vec<_> = arr.elements().iter().filter(|v| v != &b).collect();
2744    let mut dims = arr.dims().into_iter().collect::<Vec<_>>();
2745    // This access is safe because `dims` is guaranteed to be non-empty
2746    dims[0] = ArrayDimension {
2747        lower_bound: 1,
2748        length: elems.len(),
2749    };
2750
2751    Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?)
2752}
2753
2754#[sqlfunc(
2755    output_type = "Option<i32>",
2756    is_infix_op = true,
2757    sqlname = "array_upper",
2758    propagates_nulls = true,
2759    introduces_nulls = true
2760)]
2761// TODO(benesch): remove potentially dangerous usage of `as`.
2762#[allow(clippy::as_conversions)]
2763fn array_upper<'a>(a: Array<'a>, i: i64) -> Result<Option<i32>, EvalError> {
2764    if i < 1 {
2765        return Ok(None);
2766    }
2767    a.dims()
2768        .into_iter()
2769        .nth(i as usize - 1)
2770        .map(|dim| {
2771            dim.length
2772                .try_into()
2773                .map_err(|_| EvalError::Int32OutOfRange(dim.length.to_string().into()))
2774        })
2775        .transpose()
2776}
2777
2778#[sqlfunc(
2779    is_infix_op = true,
2780    sqlname = "array_contains",
2781    propagates_nulls = true,
2782    introduces_nulls = false
2783)]
2784fn array_contains<'a>(a: Datum<'a>, array: Array<'a>) -> bool {
2785    array.elements().iter().any(|e| e == a)
2786}
2787
2788#[sqlfunc(is_infix_op = true, sqlname = "@>")]
2789fn array_contains_array<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2790    let a = a.elements();
2791    let b = b.elements();
2792
2793    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
2794    if b.iter().contains(&Datum::Null) {
2795        false
2796    } else {
2797        b.iter()
2798            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
2799    }
2800}
2801
2802#[sqlfunc(is_infix_op = true, sqlname = "<@")]
2803fn array_contains_array_rev<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2804    array_contains_array(b, a)
2805}
2806
2807#[sqlfunc(
2808    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2809    is_infix_op = true,
2810    sqlname = "||",
2811    propagates_nulls = false,
2812    introduces_nulls = false
2813)]
2814fn array_array_concat<'a>(
2815    a: Option<Array<'a>>,
2816    b: Option<Array<'a>>,
2817    temp_storage: &'a RowArena,
2818) -> Result<Option<Array<'a>>, EvalError> {
2819    let Some(a_array) = a else {
2820        return Ok(b);
2821    };
2822    let Some(b_array) = b else {
2823        return Ok(a);
2824    };
2825
2826    let a_dims: Vec<ArrayDimension> = a_array.dims().into_iter().collect();
2827    let b_dims: Vec<ArrayDimension> = b_array.dims().into_iter().collect();
2828
2829    let a_ndims = a_dims.len();
2830    let b_ndims = b_dims.len();
2831
2832    // Per PostgreSQL, if either of the input arrays is zero dimensional,
2833    // the output is the other array, no matter their dimensions.
2834    if a_ndims == 0 {
2835        return Ok(b);
2836    } else if b_ndims == 0 {
2837        return Ok(a);
2838    }
2839
2840    // Postgres supports concatenating arrays of different dimensions,
2841    // as long as one of the arrays has the same type as an element of
2842    // the other array, i.e. `int[2][4] || int[4]` (or `int[4] || int[2][4]`)
2843    // works, because each element of `int[2][4]` is an `int[4]`.
2844    // This check is separate from the one below because Postgres gives a
2845    // specific error message if the number of dimensions differs by more
2846    // than one.
2847    // This cast is safe since MAX_ARRAY_DIMENSIONS is 6
2848    // Can be replaced by .abs_diff once it is stabilized
2849    // TODO(benesch): remove potentially dangerous usage of `as`.
2850    #[allow(clippy::as_conversions)]
2851    if (a_ndims as isize - b_ndims as isize).abs() > 1 {
2852        return Err(EvalError::IncompatibleArrayDimensions {
2853            dims: Some((a_ndims, b_ndims)),
2854        });
2855    }
2856
2857    let mut dims;
2858
2859    // After the checks above, we are certain that:
2860    // - neither array is zero dimensional nor empty
2861    // - both arrays have the same number of dimensions, or differ
2862    //   at most by one.
2863    match a_ndims.cmp(&b_ndims) {
2864        // If both arrays have the same number of dimensions, validate
2865        // that their inner dimensions are the same and concatenate the
2866        // arrays.
2867        Ordering::Equal => {
2868            if &a_dims[1..] != &b_dims[1..] {
2869                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2870            }
2871            dims = vec![ArrayDimension {
2872                lower_bound: a_dims[0].lower_bound,
2873                length: a_dims[0].length + b_dims[0].length,
2874            }];
2875            dims.extend(&a_dims[1..]);
2876        }
2877        // If `a` has less dimensions than `b`, this is an element-array
2878        // concatenation, which requires that `a` has the same dimensions
2879        // as an element of `b`.
2880        Ordering::Less => {
2881            if &a_dims[..] != &b_dims[1..] {
2882                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2883            }
2884            dims = vec![ArrayDimension {
2885                lower_bound: b_dims[0].lower_bound,
2886                // Since `a` is treated as an element of `b`, the length of
2887                // the first dimension of `b` is incremented by one, as `a` is
2888                // non-empty.
2889                length: b_dims[0].length + 1,
2890            }];
2891            dims.extend(a_dims);
2892        }
2893        // If `a` has more dimensions than `b`, this is an array-element
2894        // concatenation, which requires that `b` has the same dimensions
2895        // as an element of `a`.
2896        Ordering::Greater => {
2897            if &a_dims[1..] != &b_dims[..] {
2898                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2899            }
2900            dims = vec![ArrayDimension {
2901                lower_bound: a_dims[0].lower_bound,
2902                // Since `b` is treated as an element of `a`, the length of
2903                // the first dimension of `a` is incremented by one, as `b`
2904                // is non-empty.
2905                length: a_dims[0].length + 1,
2906            }];
2907            dims.extend(b_dims);
2908        }
2909    }
2910
2911    let elems = a_array.elements().iter().chain(b_array.elements().iter());
2912
2913    let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?;
2914    Ok(Some(datum.unwrap_array()))
2915}
2916
2917#[sqlfunc(
2918    is_infix_op = true,
2919    sqlname = "||",
2920    propagates_nulls = false,
2921    introduces_nulls = false
2922)]
2923fn list_list_concat<'a, T: FromDatum<'a>>(
2924    a: Option<DatumList<'a, T>>,
2925    b: Option<DatumList<'a, T>>,
2926    temp_storage: &'a RowArena,
2927) -> Option<DatumList<'a, T>> {
2928    let Some(a) = a else {
2929        return b;
2930    };
2931    let Some(b) = b else {
2932        return Some(a);
2933    };
2934
2935    Some(temp_storage.make_datum_list(a.typed_iter().chain(b.typed_iter())))
2936}
2937
2938#[sqlfunc(is_infix_op = true, sqlname = "||", propagates_nulls = false)]
2939fn list_element_concat<'a, T: FromDatum<'a>>(
2940    a: Option<DatumList<'a, T>>,
2941    b: T,
2942    temp_storage: &'a RowArena,
2943) -> DatumList<'a, T> {
2944    let a_elems = a.into_iter().flat_map(|a| a.typed_iter());
2945    temp_storage.make_datum_list(a_elems.chain(std::iter::once(b)))
2946}
2947
2948// Note that the output type corresponds to the _second_ parameter's input type.
2949#[sqlfunc(is_infix_op = true, sqlname = "||", propagates_nulls = false)]
2950fn element_list_concat<'a, T: FromDatum<'a>>(
2951    a: T,
2952    b: Option<DatumList<'a, T>>,
2953    temp_storage: &'a RowArena,
2954) -> DatumList<'a, T> {
2955    let b_elems = b.into_iter().flat_map(|b| b.typed_iter());
2956    temp_storage.make_datum_list(std::iter::once(a).chain(b_elems))
2957}
2958
2959#[sqlfunc(sqlname = "list_remove")]
2960fn list_remove<'a, T: FromDatum<'a>>(
2961    a: DatumList<'a, T>,
2962    b: T,
2963    temp_storage: &'a RowArena,
2964) -> DatumList<'a, T> {
2965    temp_storage.make_datum_list(a.typed_iter().filter(|elem| *elem != b))
2966}
2967
2968#[sqlfunc(sqlname = "digest")]
2969fn digest_string(to_digest: &str, digest_fn: &str) -> Result<Vec<u8>, EvalError> {
2970    digest_inner(to_digest.as_bytes(), digest_fn)
2971}
2972
2973#[sqlfunc(sqlname = "digest")]
2974fn digest_bytes(to_digest: &[u8], digest_fn: &str) -> Result<Vec<u8>, EvalError> {
2975    digest_inner(to_digest, digest_fn)
2976}
2977
2978fn digest_inner(bytes: &[u8], digest_fn: &str) -> Result<Vec<u8>, EvalError> {
2979    match digest_fn {
2980        "md5" => Ok(Md5::digest(bytes).to_vec()),
2981        "sha1" => Ok(digest::digest(&digest::SHA1_FOR_LEGACY_USE_ONLY, bytes)
2982            .as_ref()
2983            .to_vec()),
2984        "sha224" => Ok(digest::digest(&digest::SHA224, bytes).as_ref().to_vec()),
2985        "sha256" => Ok(digest::digest(&digest::SHA256, bytes).as_ref().to_vec()),
2986        "sha384" => Ok(digest::digest(&digest::SHA384, bytes).as_ref().to_vec()),
2987        "sha512" => Ok(digest::digest(&digest::SHA512, bytes).as_ref().to_vec()),
2988        other => Err(EvalError::InvalidHashAlgorithm(other.into())),
2989    }
2990}
2991
2992#[sqlfunc]
2993fn mz_render_typmod(oid: u32, typmod: i32) -> String {
2994    match Type::from_oid_and_typmod(oid, typmod) {
2995        Ok(typ) => typ.constraint().display_or("").to_string(),
2996        // Match dubious PostgreSQL behavior of outputting the unmodified
2997        // `typmod` when positive if the type OID/typmod is invalid.
2998        Err(_) if typmod >= 0 => format!("({typmod})"),
2999        Err(_) => "".into(),
3000    }
3001}
3002
// Unit tests for `add_timestamp_months`, plus property tests that check the
// monotonicity flags reported by `BinaryFunc::is_monotone`.
#[cfg(test)]
mod test {
    use chrono::prelude::*;
    use mz_repr::PropDatum;
    use proptest::prelude::*;

    use super::*;
    use crate::MirScalarExpr;

    /// Month arithmetic moves forwards and backwards across year boundaries.
    /// Adding months to an end-of-month date clamps it to the last day of a
    /// shorter target month, including leap-year February.
    #[mz_ore::test]
    fn add_interval_months() {
        let dt = ym(2000, 1);

        assert_eq!(add_timestamp_months(&*dt, 0).unwrap(), dt);
        assert_eq!(add_timestamp_months(&*dt, 1).unwrap(), ym(2000, 2));
        assert_eq!(add_timestamp_months(&*dt, 12).unwrap(), ym(2001, 1));
        assert_eq!(add_timestamp_months(&*dt, 13).unwrap(), ym(2001, 2));
        assert_eq!(add_timestamp_months(&*dt, 24).unwrap(), ym(2002, 1));
        assert_eq!(add_timestamp_months(&*dt, 30).unwrap(), ym(2002, 7));

        // and negatives
        assert_eq!(add_timestamp_months(&*dt, -1).unwrap(), ym(1999, 12));
        assert_eq!(add_timestamp_months(&*dt, -12).unwrap(), ym(1999, 1));
        assert_eq!(add_timestamp_months(&*dt, -13).unwrap(), ym(1998, 12));
        assert_eq!(add_timestamp_months(&*dt, -24).unwrap(), ym(1998, 1));
        assert_eq!(add_timestamp_months(&*dt, -30).unwrap(), ym(1997, 7));

        // and going over a year boundary by less than a year
        let dt = ym(1999, 12);
        assert_eq!(add_timestamp_months(&*dt, 1).unwrap(), ym(2000, 1));
        let end_of_month_dt = NaiveDate::from_ymd_opt(1999, 12, 31)
            .unwrap()
            .and_hms_opt(9, 9, 9)
            .unwrap();
        assert_eq!(
            // leap year
            add_timestamp_months(&end_of_month_dt, 2).unwrap(),
            NaiveDate::from_ymd_opt(2000, 2, 29)
                .unwrap()
                .and_hms_opt(9, 9, 9)
                .unwrap()
                .try_into()
                .unwrap(),
        );
        assert_eq!(
            // not leap year
            add_timestamp_months(&end_of_month_dt, 14).unwrap(),
            NaiveDate::from_ymd_opt(2001, 2, 28)
                .unwrap()
                .and_hms_opt(9, 9, 9)
                .unwrap()
                .try_into()
                .unwrap(),
        );
    }

    /// Builds a checked timestamp for 09:09:09 on the first day of `month` in
    /// `year`. Panics if the date is invalid or out of range.
    fn ym(year: i32, month: u32) -> CheckedTimestamp<NaiveDateTime> {
        NaiveDate::from_ymd_opt(year, month, 1)
            .unwrap()
            .and_hms_opt(9, 9, 9)
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// Property test: for a selection of binary functions, every argument
    /// position that `is_monotone()` reports as monotone must really be
    /// monotone on generated inputs.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that the function is either monotonically increasing or decreasing over
        /// the given sets of arguments.
        ///
        /// If any evaluation returns an error, the whole check is skipped.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let Ok(results) = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            else {
                return;
            };

            let forward = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let reverse = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                forward || reverse,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Runs `func` as `f(#0, #1)` on inputs drawn from `left` and `right`.
        ///
        /// For each argument position flagged monotone by `func.is_monotone()`,
        /// three sorted values are drawn for that position. Each is paired with
        /// every generated value for the other position, and the results must
        /// be monotone. Positions not flagged monotone are not checked.
        fn proptest_binary<'a>(
            func: BinaryFunc,
            arena: &'a RowArena,
            left: impl Strategy<Value = PropDatum>,
            right: impl Strategy<Value = PropDatum>,
        ) {
            let (left_monotone, right_monotone) = func.is_monotone();
            let expr = MirScalarExpr::CallBinary {
                func,
                expr1: Box::new(MirScalarExpr::column(0)),
                expr2: Box::new(MirScalarExpr::column(1)),
            };
            proptest!(|(
                mut left in proptest::array::uniform3(left),
                mut right in proptest::array::uniform3(right),
            )| {
                // Sorted inputs let `assert_monotone` compare adjacent outputs.
                left.sort();
                right.sort();
                if left_monotone {
                    // Hold the right argument fixed and vary the left.
                    for r in &right {
                        let args: Vec<[_; 2]> = left
                            .iter()
                            .map(|l| [Datum::from(l), Datum::from(r)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
                if right_monotone {
                    // Hold the left argument fixed and vary the right.
                    for l in &left {
                        let args: Vec<[_; 2]> = right
                            .iter()
                            .map(|r| [Datum::from(l), Datum::from(r)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
            });
        }

        // String inputs: short uppercase strings, mixed with the type's
        // "interesting" datums.
        let interesting_strs: Vec<_> = SqlScalarType::String.interesting_datums().collect();
        let str_datums = proptest::strategy::Union::new([
            proptest::string::string_regex("[A-Z]{0,10}")
                .expect("valid regex")
                .prop_map(|s| PropDatum::String(s.to_string()))
                .boxed(),
            (0..interesting_strs.len())
                .prop_map(move |i| {
                    let Datum::String(val) = interesting_strs[i] else {
                        unreachable!("interesting strings has non-strings")
                    };
                    PropDatum::String(val.to_string())
                })
                .boxed(),
        ]);

        // i32 inputs: arbitrary values, the type's "interesting" datums, and
        // small values in -10..10 so that nearby inputs collide more often.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| {
                    let Datum::Int32(val) = interesting_i32s[i] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_binary(
            BinaryFunc::AddInt32(AddInt32),
            &arena,
            &i32_datums,
            &i32_datums,
        );
        proptest_binary(SubInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(MulInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(DivInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(TextConcatBinary.into(), &arena, &str_datums, &str_datums);
        // NOTE(review): `left` appears to be declared `is_monotone = (false, false)`,
        // so this presumably asserts nothing today and only guards against the
        // flags being flipped without justification; confirm that is the intent.
        proptest_binary(Left.into(), &arena, &str_datums, &i32_datums);
    }
}