Skip to main content

mz_expr/scalar/
func.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14use std::borrow::Cow;
15use std::cmp::Ordering;
16use std::convert::{TryFrom, TryInto};
17use std::str::FromStr;
18use std::{iter, str};
19
20use ::encoding::DecoderTrap;
21use ::encoding::label::encoding_from_whatwg_label;
22use aws_lc_rs::constant_time::verify_slices_are_equal;
23use aws_lc_rs::digest;
24use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, TimeZone, Timelike, Utc};
25use chrono_tz::{OffsetComponents, OffsetName, Tz};
26use dec::OrderedDecimal;
27use itertools::Itertools;
28use md5::{Digest, Md5};
29use mz_expr_derive::sqlfunc;
30use mz_ore::cast::{self, CastFrom};
31use mz_ore::fmt::FormatBuffer;
32use mz_ore::lex::LexBuf;
33use mz_ore::option::OptionExt;
34use mz_pgrepr::Type;
35use mz_pgtz::timezone::{Timezone, TimezoneSpec};
36use mz_repr::adt::array::{Array, ArrayDimension};
37use mz_repr::adt::date::Date;
38use mz_repr::adt::interval::{Interval, RoundBehavior};
39use mz_repr::adt::jsonb::JsonbRef;
40use mz_repr::adt::mz_acl_item::{AclMode, MzAclItem};
41use mz_repr::adt::numeric::{self, Numeric};
42use mz_repr::adt::range::Range;
43use mz_repr::adt::regex::Regex;
44use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
45use mz_repr::{
46    ArrayRustType, Datum, DatumList, DatumMap, ExcludeNull, FromDatum, InputDatumType, Row,
47    RowArena, SqlScalarType, strconv,
48};
49use mz_sql_parser::ast::display::{AstDisplay, FormatMode};
50use mz_sql_pretty::{PrettyConfig, pretty_str};
51use num::traits::CheckedNeg;
52
53use crate::scalar::func::format::DateTimeFormat;
54use crate::{EvalError, like_pattern};
55
56#[macro_use]
57mod macros;
58mod binary;
59mod encoding;
60pub(crate) mod format;
61pub(crate) mod impls;
62mod unary;
63mod unmaterializable;
64pub mod variadic;
65
66pub use binary::BinaryFunc;
67pub use impls::*;
68pub use unary::{EagerUnaryFunc, LazyUnaryFunc, UnaryFunc};
69pub use unmaterializable::UnmaterializableFunc;
70pub use variadic::VariadicFunc;
71
/// Upper bound, in bytes, on the result of certain string functions such as `repeat` and
/// `lpad`. The value (100 MiB) was chosen as the smallest limit that keeps the existing tests
/// passing unchanged; it is probably more generous than ideal, but any limit beats none.
///
/// Note: this number is quoted in the user-facing function reference for every function to
/// which it applies.
pub const MAX_STRING_FUNC_RESULT_BYTES: usize = 100 * 1024 * 1024;
79
80pub fn jsonb_stringify<'a>(a: Datum<'a>, temp_storage: &'a RowArena) -> Option<&'a str> {
81    match a {
82        Datum::JsonNull => None,
83        Datum::String(s) => Some(s),
84        _ => {
85            let s = cast_jsonb_to_string(JsonbRef::from_datum(a));
86            Some(temp_storage.push_string(s))
87        }
88    }
89}
90
91#[sqlfunc(
92    is_monotone = "(true, true)",
93    is_infix_op = true,
94    sqlname = "+",
95    propagates_nulls = true
96)]
97fn add_int16(a: i16, b: i16) -> Result<i16, EvalError> {
98    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
99}
100
101#[sqlfunc(
102    is_monotone = "(true, true)",
103    is_infix_op = true,
104    sqlname = "+",
105    propagates_nulls = true
106)]
107fn add_int32(a: i32, b: i32) -> Result<i32, EvalError> {
108    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
109}
110
111#[sqlfunc(
112    is_monotone = "(true, true)",
113    is_infix_op = true,
114    sqlname = "+",
115    propagates_nulls = true
116)]
117fn add_int64(a: i64, b: i64) -> Result<i64, EvalError> {
118    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
119}
120
121#[sqlfunc(
122    is_monotone = "(true, true)",
123    is_infix_op = true,
124    sqlname = "+",
125    propagates_nulls = true
126)]
127fn add_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
128    a.checked_add(b)
129        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} + {b}").into()))
130}
131
132#[sqlfunc(
133    is_monotone = "(true, true)",
134    is_infix_op = true,
135    sqlname = "+",
136    propagates_nulls = true
137)]
138fn add_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
139    a.checked_add(b)
140        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} + {b}").into()))
141}
142
143#[sqlfunc(
144    is_monotone = "(true, true)",
145    is_infix_op = true,
146    sqlname = "+",
147    propagates_nulls = true
148)]
149fn add_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
150    a.checked_add(b)
151        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} + {b}").into()))
152}
153
154#[sqlfunc(
155    is_monotone = "(true, true)",
156    is_infix_op = true,
157    sqlname = "+",
158    propagates_nulls = true
159)]
160fn add_float32(a: f32, b: f32) -> Result<f32, EvalError> {
161    let sum = a + b;
162    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
163        Err(EvalError::FloatOverflow)
164    } else {
165        Ok(sum)
166    }
167}
168
169#[sqlfunc(
170    is_monotone = "(true, true)",
171    is_infix_op = true,
172    sqlname = "+",
173    propagates_nulls = true
174)]
175fn add_float64(a: f64, b: f64) -> Result<f64, EvalError> {
176    let sum = a + b;
177    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
178        Err(EvalError::FloatOverflow)
179    } else {
180        Ok(sum)
181    }
182}
183
184// `Interval` is lex-ordered (months, days, micros), but adding an interval to a
185// timestamp adds *calendar* months (with day-clamping) which does not respect
186// that ordering: e.g. `i1 = {0 months, 31 days}` is lex-less than
187// `i2 = {1 month, 0 days}`, but `2024-01-31 + i1 = 2024-03-02` is greater than
188// `2024-01-31 + i2 = 2024-02-29`. Day-clamping plus preserved sub-day time also
189// breaks monotonicity in the first argument near month boundaries.
190#[sqlfunc(is_monotone = "(false, false)", is_infix_op = true, sqlname = "+")]
191fn add_timestamp_interval(
192    a: CheckedTimestamp<NaiveDateTime>,
193    b: Interval,
194) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
195    add_timestamplike_interval(a, b)
196}
197
198#[sqlfunc(is_monotone = "(false, false)", is_infix_op = true, sqlname = "+")]
199fn add_timestamp_tz_interval(
200    a: CheckedTimestamp<DateTime<Utc>>,
201    b: Interval,
202) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
203    add_timestamplike_interval(a, b)
204}
205
206fn add_timestamplike_interval<T>(
207    a: CheckedTimestamp<T>,
208    b: Interval,
209) -> Result<CheckedTimestamp<T>, EvalError>
210where
211    T: TimestampLike,
212{
213    let dt = a.date_time();
214    let dt = add_timestamp_months(&dt, b.months)?;
215    let dt = dt
216        .checked_add_signed(b.duration_as_chrono())
217        .ok_or(EvalError::TimestampOutOfRange)?;
218    Ok(CheckedTimestamp::from_timestamplike(T::from_date_time(dt))?)
219}
220
221// See `add_timestamp_interval` for why this is not monotone.
222#[sqlfunc(is_monotone = "(false, false)", is_infix_op = true, sqlname = "-")]
223fn sub_timestamp_interval(
224    a: CheckedTimestamp<NaiveDateTime>,
225    b: Interval,
226) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
227    sub_timestamplike_interval(a, b)
228}
229
230#[sqlfunc(is_monotone = "(false, false)", is_infix_op = true, sqlname = "-")]
231fn sub_timestamp_tz_interval(
232    a: CheckedTimestamp<DateTime<Utc>>,
233    b: Interval,
234) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
235    sub_timestamplike_interval(a, b)
236}
237
238fn sub_timestamplike_interval<T>(
239    a: CheckedTimestamp<T>,
240    b: Interval,
241) -> Result<CheckedTimestamp<T>, EvalError>
242where
243    T: TimestampLike,
244{
245    neg_interval_inner(b).and_then(|i| add_timestamplike_interval(a, i))
246}
247
248#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
249fn add_date_time(
250    date: Date,
251    time: chrono::NaiveTime,
252) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
253    let dt = NaiveDate::from(date)
254        .and_hms_nano_opt(time.hour(), time.minute(), time.second(), time.nanosecond())
255        .unwrap();
256    Ok(CheckedTimestamp::from_timestamplike(dt)?)
257}
258
259// Monotone in `date` (dates have no sub-day component, so day-clamping at month
260// boundaries only causes results to collapse, never to reverse), but not in
261// `interval`: e.g. `{0 months, 31 days}` is lex-less than `{1 month, 0 days}`,
262// but adding the former to `2024-01-31` gives `2024-03-02` while the latter
263// gives `2024-02-29`.
264#[sqlfunc(is_monotone = "(true, false)", is_infix_op = true, sqlname = "+")]
265fn add_date_interval(
266    date: Date,
267    interval: Interval,
268) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
269    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
270    let dt = add_timestamp_months(&dt, interval.months)?;
271    let dt = dt
272        .checked_add_signed(interval.duration_as_chrono())
273        .ok_or(EvalError::TimestampOutOfRange)?;
274    Ok(CheckedTimestamp::from_timestamplike(dt)?)
275}
276
277#[sqlfunc(
278    // <time> + <interval> wraps!
279    is_monotone = "(false, false)",
280    is_infix_op = true,
281    sqlname = "+",
282    propagates_nulls = true
283)]
284fn add_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
285    let (t, _) = time.overflowing_add_signed(interval.duration_as_chrono());
286    t
287}
288
289#[sqlfunc(
290    is_monotone = "(true, false)",
291    output_type = "Numeric",
292    sqlname = "round",
293    propagates_nulls = true
294)]
295fn round_numeric_binary(a: OrderedDecimal<Numeric>, mut b: i32) -> Result<Numeric, EvalError> {
296    let mut a = a.0;
297    let mut cx = numeric::cx_datum();
298    let a_exp = a.exponent();
299    if a_exp > 0 && b > 0 || a_exp < 0 && -a_exp < b {
300        // This condition indicates:
301        // - a is a value without a decimal point, b is a positive number
302        // - a has a decimal point, but b is larger than its scale
303        // In both of these situations, right-pad the number with zeroes, which // is most easily done with rescale.
304
305        // Ensure rescale doesn't exceed max precision by putting a ceiling on
306        // b equal to the maximum remaining scale the value can support.
307        let max_remaining_scale = u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
308            - (numeric::get_precision(&a) - numeric::get_scale(&a));
309        b = match i32::try_from(max_remaining_scale) {
310            Ok(max_remaining_scale) => std::cmp::min(b, max_remaining_scale),
311            Err(_) => b,
312        };
313        cx.rescale(&mut a, &numeric::Numeric::from(-b));
314    } else {
315        // To avoid invalid operations, clamp b to be within 1 more than the
316        // precision limit.
317        const MAX_P_LIMIT: i32 = 1 + cast::u8_to_i32(numeric::NUMERIC_DATUM_MAX_PRECISION);
318        b = std::cmp::min(MAX_P_LIMIT, b);
319        b = std::cmp::max(-MAX_P_LIMIT, b);
320        let mut b = numeric::Numeric::from(b);
321        // Shift by 10^b; this put digit to round to in the one's place.
322        cx.scaleb(&mut a, &b);
323        cx.round(&mut a);
324        // Negate exponent for shift back
325        cx.neg(&mut b);
326        cx.scaleb(&mut a, &b);
327    }
328
329    if cx.status().overflow() {
330        Err(EvalError::FloatOverflow)
331    } else if a.is_zero() {
332        // simpler than handling cases where exponent has gotten set to some
333        // value greater than the max precision, but all significant digits
334        // were rounded away.
335        Ok(numeric::Numeric::zero())
336    } else {
337        numeric::munge_numeric(&mut a).unwrap();
338        Ok(a)
339    }
340}
341
342#[sqlfunc(sqlname = "convert_from", propagates_nulls = true)]
343fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
344    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
345    // which the encoding library uses[3].
346    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
347    // [2]: https://encoding.spec.whatwg.org/
348    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
349    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
350
351    // Supporting other encodings is tracked by database-issues#797.
352    if encoding_from_whatwg_label(&encoding_name).map(|e| e.name()) != Some("utf-8") {
353        return Err(EvalError::InvalidEncodingName(encoding_name));
354    }
355
356    match str::from_utf8(a) {
357        Ok(from) => Ok(from),
358        Err(e) => Err(EvalError::InvalidByteSequence {
359            byte_sequence: e.to_string().into(),
360            encoding_name,
361        }),
362    }
363}
364
365#[sqlfunc]
366fn encode(bytes: &[u8], format: &str) -> Result<String, EvalError> {
367    let format = encoding::lookup_format(format)?;
368    Ok(format.encode(bytes))
369}
370
371#[sqlfunc]
372fn decode(string: &str, format: &str) -> Result<Vec<u8>, EvalError> {
373    let format = encoding::lookup_format(format)?;
374    let out = format.decode(string)?;
375    if out.len() > MAX_STRING_FUNC_RESULT_BYTES {
376        Err(EvalError::LengthTooLarge)
377    } else {
378        Ok(out)
379    }
380}
381
382#[sqlfunc(sqlname = "length", propagates_nulls = true)]
383fn encoded_bytes_char_length(a: &[u8], b: &str) -> Result<i32, EvalError> {
384    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
385    // which the encoding library uses[3].
386    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
387    // [2]: https://encoding.spec.whatwg.org/
388    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
389    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
390
391    let enc = match encoding_from_whatwg_label(&encoding_name) {
392        Some(enc) => enc,
393        None => return Err(EvalError::InvalidEncodingName(encoding_name)),
394    };
395
396    let decoded_string = match enc.decode(a, DecoderTrap::Strict) {
397        Ok(s) => s,
398        Err(e) => {
399            return Err(EvalError::InvalidByteSequence {
400                byte_sequence: e.into(),
401                encoding_name,
402            });
403        }
404    };
405
406    let count = decoded_string.chars().count();
407    i32::try_from(count).map_err(|_| EvalError::Int32OutOfRange(count.to_string().into()))
408}
409
410// TODO(benesch): remove potentially dangerous usage of `as`.
411#[allow(clippy::as_conversions)]
412pub fn add_timestamp_months<T: TimestampLike>(
413    dt: &T,
414    mut months: i32,
415) -> Result<CheckedTimestamp<T>, EvalError> {
416    if months == 0 {
417        return Ok(CheckedTimestamp::from_timestamplike(dt.clone())?);
418    }
419
420    let (mut year, mut month, mut day) = (dt.year(), dt.month0() as i32, dt.day());
421    let years = months / 12;
422    year = year
423        .checked_add(years)
424        .ok_or(EvalError::TimestampOutOfRange)?;
425
426    months %= 12;
427    // positive modulus is easier to reason about
428    if months < 0 {
429        year -= 1;
430        months += 12;
431    }
432    year += (month + months) / 12;
433    month = (month + months) % 12;
434    // account for dt.month0
435    month += 1;
436
437    // handle going from January 31st to February by saturation
438    let mut new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
439    while new_d.is_none() {
440        // If we have decremented day past 28 and are still receiving `None`,
441        // then we have generally overflowed `NaiveDate`.
442        if day < 28 {
443            return Err(EvalError::TimestampOutOfRange);
444        }
445        day -= 1;
446        new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
447    }
448    let new_d = new_d.unwrap();
449
450    // Neither postgres nor mysql support leap seconds, so this should be safe.
451    //
452    // Both my testing and https://dba.stackexchange.com/a/105829 support the
453    // idea that we should ignore leap seconds
454    let new_dt = new_d
455        .and_hms_nano_opt(dt.hour(), dt.minute(), dt.second(), dt.nanosecond())
456        .unwrap();
457    let new_dt = T::from_date_time(new_dt);
458    Ok(CheckedTimestamp::from_timestamplike(new_dt)?)
459}
460
461#[sqlfunc(
462    is_monotone = "(true, true)",
463    is_infix_op = true,
464    sqlname = "+",
465    propagates_nulls = true
466)]
467fn add_numeric(
468    a: OrderedDecimal<Numeric>,
469    b: OrderedDecimal<Numeric>,
470) -> Result<Numeric, EvalError> {
471    let mut cx = numeric::cx_datum();
472    let mut a = a.0;
473    cx.add(&mut a, &b.0);
474    if cx.status().overflow() {
475        Err(EvalError::FloatOverflow)
476    } else {
477        Ok(a)
478    }
479}
480
481#[sqlfunc(
482    is_monotone = "(true, true)",
483    is_infix_op = true,
484    sqlname = "+",
485    propagates_nulls = true
486)]
487fn add_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
488    a.checked_add(&b)
489        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} + {b}").into()))
490}
491
492#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
493fn bit_and_int16(a: i16, b: i16) -> i16 {
494    a & b
495}
496
497#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
498fn bit_and_int32(a: i32, b: i32) -> i32 {
499    a & b
500}
501
502#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
503fn bit_and_int64(a: i64, b: i64) -> i64 {
504    a & b
505}
506
507#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
508fn bit_and_uint16(a: u16, b: u16) -> u16 {
509    a & b
510}
511
512#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
513fn bit_and_uint32(a: u32, b: u32) -> u32 {
514    a & b
515}
516
517#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
518fn bit_and_uint64(a: u64, b: u64) -> u64 {
519    a & b
520}
521
522#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
523fn bit_or_int16(a: i16, b: i16) -> i16 {
524    a | b
525}
526
527#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
528fn bit_or_int32(a: i32, b: i32) -> i32 {
529    a | b
530}
531
532#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
533fn bit_or_int64(a: i64, b: i64) -> i64 {
534    a | b
535}
536
537#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
538fn bit_or_uint16(a: u16, b: u16) -> u16 {
539    a | b
540}
541
542#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
543fn bit_or_uint32(a: u32, b: u32) -> u32 {
544    a | b
545}
546
547#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
548fn bit_or_uint64(a: u64, b: u64) -> u64 {
549    a | b
550}
551
552#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
553fn bit_xor_int16(a: i16, b: i16) -> i16 {
554    a ^ b
555}
556
557#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
558fn bit_xor_int32(a: i32, b: i32) -> i32 {
559    a ^ b
560}
561
562#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
563fn bit_xor_int64(a: i64, b: i64) -> i64 {
564    a ^ b
565}
566
567#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
568fn bit_xor_uint16(a: u16, b: u16) -> u16 {
569    a ^ b
570}
571
572#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
573fn bit_xor_uint32(a: u32, b: u32) -> u32 {
574    a ^ b
575}
576
577#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
578fn bit_xor_uint64(a: u64, b: u64) -> u64 {
579    a ^ b
580}
581
582#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
583// TODO(benesch): remove potentially dangerous usage of `as`.
584#[allow(clippy::as_conversions)]
585fn bit_shift_left_int16(a: i16, b: i32) -> i16 {
586    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
587    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
588    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
589    let lhs: i32 = a as i32;
590    let rhs: u32 = b as u32;
591    lhs.wrapping_shl(rhs) as i16
592}
593
594#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
595// TODO(benesch): remove potentially dangerous usage of `as`.
596#[allow(clippy::as_conversions)]
597fn bit_shift_left_int32(lhs: i32, rhs: i32) -> i32 {
598    let rhs = rhs as u32;
599    lhs.wrapping_shl(rhs)
600}
601
602#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
603// TODO(benesch): remove potentially dangerous usage of `as`.
604#[allow(clippy::as_conversions)]
605fn bit_shift_left_int64(lhs: i64, rhs: i32) -> i64 {
606    let rhs = rhs as u32;
607    lhs.wrapping_shl(rhs)
608}
609
610#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
611// TODO(benesch): remove potentially dangerous usage of `as`.
612#[allow(clippy::as_conversions)]
613fn bit_shift_left_uint16(a: u16, b: u32) -> u16 {
614    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
615    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
616    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
617    let lhs: u32 = a as u32;
618    let rhs: u32 = b;
619    lhs.wrapping_shl(rhs) as u16
620}
621
622#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
623fn bit_shift_left_uint32(a: u32, b: u32) -> u32 {
624    let lhs = a;
625    let rhs = b;
626    lhs.wrapping_shl(rhs)
627}
628
629#[sqlfunc(
630    output_type = "u64",
631    is_infix_op = true,
632    sqlname = "<<",
633    propagates_nulls = true
634)]
635fn bit_shift_left_uint64(lhs: u64, rhs: u32) -> u64 {
636    lhs.wrapping_shl(rhs)
637}
638
639#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
640// TODO(benesch): remove potentially dangerous usage of `as`.
641#[allow(clippy::as_conversions)]
642fn bit_shift_right_int16(lhs: i16, rhs: i32) -> i16 {
643    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
644    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
645    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
646    let lhs = lhs as i32;
647    let rhs = rhs as u32;
648    lhs.wrapping_shr(rhs) as i16
649}
650
651#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
652// TODO(benesch): remove potentially dangerous usage of `as`.
653#[allow(clippy::as_conversions)]
654fn bit_shift_right_int32(lhs: i32, rhs: i32) -> i32 {
655    lhs.wrapping_shr(rhs as u32)
656}
657
658#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
659// TODO(benesch): remove potentially dangerous usage of `as`.
660#[allow(clippy::as_conversions)]
661fn bit_shift_right_int64(lhs: i64, rhs: i32) -> i64 {
662    lhs.wrapping_shr(rhs as u32)
663}
664
665#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
666// TODO(benesch): remove potentially dangerous usage of `as`.
667#[allow(clippy::as_conversions)]
668fn bit_shift_right_uint16(lhs: u16, rhs: u32) -> u16 {
669    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
670    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
671    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
672    let lhs = lhs as u32;
673    lhs.wrapping_shr(rhs) as u16
674}
675
676#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
677fn bit_shift_right_uint32(lhs: u32, rhs: u32) -> u32 {
678    lhs.wrapping_shr(rhs)
679}
680
681#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
682fn bit_shift_right_uint64(lhs: u64, rhs: u32) -> u64 {
683    lhs.wrapping_shr(rhs)
684}
685
686#[sqlfunc(
687    is_monotone = "(true, true)",
688    is_infix_op = true,
689    sqlname = "-",
690    propagates_nulls = true
691)]
692fn sub_int16(a: i16, b: i16) -> Result<i16, EvalError> {
693    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
694}
695
696#[sqlfunc(
697    is_monotone = "(true, true)",
698    is_infix_op = true,
699    sqlname = "-",
700    propagates_nulls = true
701)]
702fn sub_int32(a: i32, b: i32) -> Result<i32, EvalError> {
703    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
704}
705
706#[sqlfunc(
707    is_monotone = "(true, true)",
708    is_infix_op = true,
709    sqlname = "-",
710    propagates_nulls = true
711)]
712fn sub_int64(a: i64, b: i64) -> Result<i64, EvalError> {
713    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
714}
715
716#[sqlfunc(
717    is_monotone = "(true, true)",
718    is_infix_op = true,
719    sqlname = "-",
720    propagates_nulls = true
721)]
722fn sub_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
723    a.checked_sub(b)
724        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} - {b}").into()))
725}
726
727#[sqlfunc(
728    is_monotone = "(true, true)",
729    is_infix_op = true,
730    sqlname = "-",
731    propagates_nulls = true
732)]
733fn sub_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
734    a.checked_sub(b)
735        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} - {b}").into()))
736}
737
738#[sqlfunc(
739    is_monotone = "(true, true)",
740    is_infix_op = true,
741    sqlname = "-",
742    propagates_nulls = true
743)]
744fn sub_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
745    a.checked_sub(b)
746        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} - {b}").into()))
747}
748
749#[sqlfunc(
750    is_monotone = "(true, true)",
751    is_infix_op = true,
752    sqlname = "-",
753    propagates_nulls = true
754)]
755fn sub_float32(a: f32, b: f32) -> Result<f32, EvalError> {
756    let difference = a - b;
757    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
758        Err(EvalError::FloatOverflow)
759    } else {
760        Ok(difference)
761    }
762}
763
764#[sqlfunc(
765    is_monotone = "(true, true)",
766    is_infix_op = true,
767    sqlname = "-",
768    propagates_nulls = true
769)]
770fn sub_float64(a: f64, b: f64) -> Result<f64, EvalError> {
771    let difference = a - b;
772    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
773        Err(EvalError::FloatOverflow)
774    } else {
775        Ok(difference)
776    }
777}
778
779#[sqlfunc(
780    is_monotone = "(true, true)",
781    is_infix_op = true,
782    sqlname = "-",
783    propagates_nulls = true
784)]
785fn sub_numeric(
786    a: OrderedDecimal<Numeric>,
787    b: OrderedDecimal<Numeric>,
788) -> Result<Numeric, EvalError> {
789    let mut cx = numeric::cx_datum();
790    let mut a = a.0;
791    cx.sub(&mut a, &b.0);
792    if cx.status().overflow() {
793        Err(EvalError::FloatOverflow)
794    } else {
795        Ok(a)
796    }
797}
798
799// `age(a, b)` is non-monotone in *both* arguments:
800//
801// * Lex order on `Interval` is `(months, days, micros)`, but the Postgres
802//   `age` algorithm independently subtracts year/month/day/... fields and
803//   then *borrows* across boundaries when a lower field goes negative. With
804//   `b = 2024-02-15` fixed:
805//     a = 2024-03-31  →  age = {1 month, 16 days}
806//     a = 2024-04-01  →  age = {1 month, 15 days}
807//     a = 2024-05-01  →  age = {2 months, 15 days}
808//   As `a` increases past a month boundary, `months` jumps by 1 and `days`
809//   drops, producing a lex-smaller interval than the previous step.
810//
811// * Holding `a` fixed and varying `b`, the result has a V-shape at `a == b`
812//   (sign is flipped when `a < b`):
813//     a = 2024-02-15, b = 2024-02-14  →  age = {0 months, 1 day}
814//     a = 2024-02-15, b = 2024-02-15  →  age = {0 months, 0 days}
815//     a = 2024-02-15, b = 2024-02-16  →  age = {0 months, 1 day}
816#[sqlfunc(sqlname = "age")]
817fn age_timestamp(
818    a: CheckedTimestamp<chrono::NaiveDateTime>,
819    b: CheckedTimestamp<chrono::NaiveDateTime>,
820) -> Result<Interval, EvalError> {
821    Ok(a.age(&b)?)
822}
823
824// See `age_timestamp` for why this is not monotone in either argument.
825#[sqlfunc(sqlname = "age")]
826fn age_timestamp_tz(
827    a: CheckedTimestamp<chrono::DateTime<Utc>>,
828    b: CheckedTimestamp<chrono::DateTime<Utc>>,
829) -> Result<Interval, EvalError> {
830    Ok(a.age(&b)?)
831}
832
833#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
834fn sub_timestamp(
835    a: CheckedTimestamp<NaiveDateTime>,
836    b: CheckedTimestamp<NaiveDateTime>,
837) -> Result<Interval, EvalError> {
838    Interval::from_chrono_duration(a - b)
839        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
840}
841
842#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
843fn sub_timestamp_tz(
844    a: CheckedTimestamp<chrono::DateTime<Utc>>,
845    b: CheckedTimestamp<chrono::DateTime<Utc>>,
846) -> Result<Interval, EvalError> {
847    Interval::from_chrono_duration(a - b)
848        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
849}
850
851#[sqlfunc(
852    is_monotone = "(true, true)",
853    is_infix_op = true,
854    sqlname = "-",
855    propagates_nulls = true
856)]
857fn sub_date(a: Date, b: Date) -> i32 {
858    a - b
859}
860
861#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
862fn sub_time(a: chrono::NaiveTime, b: chrono::NaiveTime) -> Result<Interval, EvalError> {
863    Interval::from_chrono_duration(a - b)
864        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
865}
866
867#[sqlfunc(
868    is_monotone = "(true, true)",
869    output_type = "Interval",
870    is_infix_op = true,
871    sqlname = "-",
872    propagates_nulls = true
873)]
874fn sub_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
875    b.checked_neg()
876        .and_then(|b| b.checked_add(&a))
877        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} - {b}").into()))
878}
879
880// See `add_date_interval` for why this is not monotone in `interval`.
881#[sqlfunc(
882    is_monotone = "(true, false)",
883    is_infix_op = true,
884    sqlname = "-",
885    propagates_nulls = true
886)]
887fn sub_date_interval(
888    date: Date,
889    interval: Interval,
890) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
891    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
892    let dt = interval
893        .months
894        .checked_neg()
895        .ok_or_else(|| EvalError::IntervalOutOfRange(interval.months.to_string().into()))
896        .and_then(|months| add_timestamp_months(&dt, months))?;
897    let dt = dt
898        .checked_sub_signed(interval.duration_as_chrono())
899        .ok_or(EvalError::TimestampOutOfRange)?;
900    Ok(dt.try_into()?)
901}
902
903#[sqlfunc(
904    is_monotone = "(false, false)",
905    is_infix_op = true,
906    sqlname = "-",
907    propagates_nulls = true
908)]
909fn sub_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
910    let (t, _) = time.overflowing_sub_signed(interval.duration_as_chrono());
911    t
912}
913
914#[sqlfunc(
915    is_monotone = "(true, true)",
916    is_infix_op = true,
917    sqlname = "*",
918    propagates_nulls = true
919)]
920fn mul_int16(a: i16, b: i16) -> Result<i16, EvalError> {
921    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
922}
923
924#[sqlfunc(
925    is_monotone = "(true, true)",
926    is_infix_op = true,
927    sqlname = "*",
928    propagates_nulls = true
929)]
930fn mul_int32(a: i32, b: i32) -> Result<i32, EvalError> {
931    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
932}
933
934#[sqlfunc(
935    is_monotone = "(true, true)",
936    is_infix_op = true,
937    sqlname = "*",
938    propagates_nulls = true
939)]
940fn mul_int64(a: i64, b: i64) -> Result<i64, EvalError> {
941    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
942}
943
944#[sqlfunc(
945    is_monotone = "(true, true)",
946    is_infix_op = true,
947    sqlname = "*",
948    propagates_nulls = true
949)]
950fn mul_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
951    a.checked_mul(b)
952        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} * {b}").into()))
953}
954
955#[sqlfunc(
956    is_monotone = "(true, true)",
957    is_infix_op = true,
958    sqlname = "*",
959    propagates_nulls = true
960)]
961fn mul_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
962    a.checked_mul(b)
963        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} * {b}").into()))
964}
965
966#[sqlfunc(
967    is_monotone = "(true, true)",
968    is_infix_op = true,
969    sqlname = "*",
970    propagates_nulls = true
971)]
972fn mul_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
973    a.checked_mul(b)
974        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} * {b}").into()))
975}
976
977#[sqlfunc(
978    is_monotone = (true, true),
979    is_infix_op = true,
980    sqlname = "*",
981    propagates_nulls = true
982)]
983fn mul_float32(a: f32, b: f32) -> Result<f32, EvalError> {
984    let product = a * b;
985    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
986        Err(EvalError::FloatOverflow)
987    } else if product == 0.0f32 && a != 0.0f32 && b != 0.0f32 {
988        Err(EvalError::FloatUnderflow)
989    } else {
990        Ok(product)
991    }
992}
993
994#[sqlfunc(
995    is_monotone = "(true, true)",
996    is_infix_op = true,
997    sqlname = "*",
998    propagates_nulls = true
999)]
1000fn mul_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1001    let product = a * b;
1002    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
1003        Err(EvalError::FloatOverflow)
1004    } else if product == 0.0f64 && a != 0.0f64 && b != 0.0f64 {
1005        Err(EvalError::FloatUnderflow)
1006    } else {
1007        Ok(product)
1008    }
1009}
1010
1011#[sqlfunc(
1012    is_monotone = "(true, true)",
1013    is_infix_op = true,
1014    sqlname = "*",
1015    propagates_nulls = true
1016)]
1017fn mul_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1018    let mut cx = numeric::cx_datum();
1019    cx.mul(&mut a, &b);
1020    let cx_status = cx.status();
1021    if cx_status.overflow() {
1022        Err(EvalError::FloatOverflow)
1023    } else if cx_status.subnormal() {
1024        Err(EvalError::FloatUnderflow)
1025    } else {
1026        numeric::munge_numeric(&mut a).unwrap();
1027        Ok(a)
1028    }
1029}
1030
1031#[sqlfunc(
1032    is_monotone = "(false, false)",
1033    is_infix_op = true,
1034    sqlname = "*",
1035    propagates_nulls = true
1036)]
1037fn mul_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1038    a.checked_mul(b)
1039        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} * {b}").into()))
1040}
1041
1042#[sqlfunc(
1043    is_monotone = "(true, false)",
1044    is_infix_op = true,
1045    sqlname = "/",
1046    propagates_nulls = true
1047)]
1048fn div_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1049    if b == 0 {
1050        Err(EvalError::DivisionByZero)
1051    } else {
1052        a.checked_div(b)
1053            .ok_or_else(|| EvalError::Int16OutOfRange(format!("{a} / {b}").into()))
1054    }
1055}
1056
1057#[sqlfunc(
1058    is_monotone = "(true, false)",
1059    is_infix_op = true,
1060    sqlname = "/",
1061    propagates_nulls = true
1062)]
1063fn div_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1064    if b == 0 {
1065        Err(EvalError::DivisionByZero)
1066    } else {
1067        a.checked_div(b)
1068            .ok_or_else(|| EvalError::Int32OutOfRange(format!("{a} / {b}").into()))
1069    }
1070}
1071
1072#[sqlfunc(
1073    is_monotone = "(true, false)",
1074    is_infix_op = true,
1075    sqlname = "/",
1076    propagates_nulls = true
1077)]
1078fn div_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1079    if b == 0 {
1080        Err(EvalError::DivisionByZero)
1081    } else {
1082        a.checked_div(b)
1083            .ok_or_else(|| EvalError::Int64OutOfRange(format!("{a} / {b}").into()))
1084    }
1085}
1086
1087#[sqlfunc(
1088    is_monotone = "(true, false)",
1089    is_infix_op = true,
1090    sqlname = "/",
1091    propagates_nulls = true
1092)]
1093fn div_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1094    if b == 0 {
1095        Err(EvalError::DivisionByZero)
1096    } else {
1097        Ok(a / b)
1098    }
1099}
1100
1101#[sqlfunc(
1102    is_monotone = "(true, false)",
1103    is_infix_op = true,
1104    sqlname = "/",
1105    propagates_nulls = true
1106)]
1107fn div_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1108    if b == 0 {
1109        Err(EvalError::DivisionByZero)
1110    } else {
1111        Ok(a / b)
1112    }
1113}
1114
1115#[sqlfunc(
1116    is_monotone = "(true, false)",
1117    is_infix_op = true,
1118    sqlname = "/",
1119    propagates_nulls = true
1120)]
1121fn div_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1122    if b == 0 {
1123        Err(EvalError::DivisionByZero)
1124    } else {
1125        Ok(a / b)
1126    }
1127}
1128
1129#[sqlfunc(
1130    is_monotone = "(true, false)",
1131    is_infix_op = true,
1132    sqlname = "/",
1133    propagates_nulls = true
1134)]
1135fn div_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1136    if b == 0.0f32 && !a.is_nan() {
1137        Err(EvalError::DivisionByZero)
1138    } else {
1139        let quotient = a / b;
1140        if quotient.is_infinite() && !a.is_infinite() {
1141            Err(EvalError::FloatOverflow)
1142        } else if quotient == 0.0f32 && a != 0.0f32 && !b.is_infinite() {
1143            Err(EvalError::FloatUnderflow)
1144        } else {
1145            Ok(quotient)
1146        }
1147    }
1148}
1149
1150#[sqlfunc(
1151    is_monotone = "(true, false)",
1152    is_infix_op = true,
1153    sqlname = "/",
1154    propagates_nulls = true
1155)]
1156fn div_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1157    if b == 0.0f64 && !a.is_nan() {
1158        Err(EvalError::DivisionByZero)
1159    } else {
1160        let quotient = a / b;
1161        if quotient.is_infinite() && !a.is_infinite() {
1162            Err(EvalError::FloatOverflow)
1163        } else if quotient == 0.0f64 && a != 0.0f64 && !b.is_infinite() {
1164            Err(EvalError::FloatUnderflow)
1165        } else {
1166            Ok(quotient)
1167        }
1168    }
1169}
1170
1171#[sqlfunc(
1172    is_monotone = "(true, false)",
1173    is_infix_op = true,
1174    sqlname = "/",
1175    propagates_nulls = true
1176)]
1177fn div_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1178    let mut cx = numeric::cx_datum();
1179
1180    cx.div(&mut a, &b);
1181    let cx_status = cx.status();
1182
1183    // checking the status for division by zero errors is insufficient because
1184    // the underlying library treats 0/0 as undefined and not division by zero.
1185    if b.is_zero() {
1186        Err(EvalError::DivisionByZero)
1187    } else if cx_status.overflow() {
1188        Err(EvalError::FloatOverflow)
1189    } else if cx_status.subnormal() {
1190        Err(EvalError::FloatUnderflow)
1191    } else {
1192        numeric::munge_numeric(&mut a).unwrap();
1193        Ok(a)
1194    }
1195}
1196
1197#[sqlfunc(
1198    is_monotone = "(false, false)",
1199    is_infix_op = true,
1200    sqlname = "/",
1201    propagates_nulls = true
1202)]
1203fn div_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1204    if b == 0.0 {
1205        Err(EvalError::DivisionByZero)
1206    } else {
1207        a.checked_div(b)
1208            .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} / {b}").into()))
1209    }
1210}
1211
1212#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1213fn mod_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1214    if b == 0 {
1215        Err(EvalError::DivisionByZero)
1216    } else {
1217        Ok(a.checked_rem(b).unwrap_or(0))
1218    }
1219}
1220
1221#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1222fn mod_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1223    if b == 0 {
1224        Err(EvalError::DivisionByZero)
1225    } else {
1226        Ok(a.checked_rem(b).unwrap_or(0))
1227    }
1228}
1229
1230#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1231fn mod_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1232    if b == 0 {
1233        Err(EvalError::DivisionByZero)
1234    } else {
1235        Ok(a.checked_rem(b).unwrap_or(0))
1236    }
1237}
1238
1239#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1240fn mod_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1241    if b == 0 {
1242        Err(EvalError::DivisionByZero)
1243    } else {
1244        Ok(a % b)
1245    }
1246}
1247
1248#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1249fn mod_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1250    if b == 0 {
1251        Err(EvalError::DivisionByZero)
1252    } else {
1253        Ok(a % b)
1254    }
1255}
1256
1257#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1258fn mod_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1259    if b == 0 {
1260        Err(EvalError::DivisionByZero)
1261    } else {
1262        Ok(a % b)
1263    }
1264}
1265
1266#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1267fn mod_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1268    if b == 0.0 {
1269        Err(EvalError::DivisionByZero)
1270    } else {
1271        Ok(a % b)
1272    }
1273}
1274
1275#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1276fn mod_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1277    if b == 0.0 {
1278        Err(EvalError::DivisionByZero)
1279    } else {
1280        Ok(a % b)
1281    }
1282}
1283
1284#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1285fn mod_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1286    if b.is_zero() {
1287        return Err(EvalError::DivisionByZero);
1288    }
1289    let mut cx = numeric::cx_datum();
1290    // Postgres does _not_ use IEEE 754-style remainder
1291    cx.rem(&mut a, &b);
1292    numeric::munge_numeric(&mut a).unwrap();
1293    Ok(a)
1294}
1295
1296fn neg_interval_inner(a: Interval) -> Result<Interval, EvalError> {
1297    a.checked_neg()
1298        .ok_or_else(|| EvalError::IntervalOutOfRange(a.to_string().into()))
1299}
1300
1301fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
1302    if val.is_negative() {
1303        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
1304    }
1305    if val.is_zero() {
1306        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
1307    }
1308    Ok(())
1309}
1310
1311#[sqlfunc(sqlname = "log", propagates_nulls = true)]
1312fn log_base_numeric(mut a: Numeric, mut b: Numeric) -> Result<Numeric, EvalError> {
1313    log_guard_numeric(&a, "log")?;
1314    log_guard_numeric(&b, "log")?;
1315    let mut cx = numeric::cx_datum();
1316    cx.ln(&mut a);
1317    cx.ln(&mut b);
1318    cx.div(&mut b, &a);
1319    if a.is_zero() {
1320        Err(EvalError::DivisionByZero)
1321    } else {
1322        // This division can result in slightly wrong answers due to the
1323        // limitation of dividing irrational numbers. To correct that, see if
1324        // rounding off the value from its `numeric::NUMERIC_DATUM_MAX_PRECISION
1325        // - 1`th position results in an integral value.
1326        cx.set_precision(usize::from(numeric::NUMERIC_DATUM_MAX_PRECISION - 1))
1327            .expect("reducing precision below max always succeeds");
1328        let mut integral_check = b.clone();
1329
1330        // `reduce` rounds to the context's final digit when the number of
1331        // digits in its argument exceeds its precision. We've contrived that to
1332        // happen by shrinking the context's precision by 1.
1333        cx.reduce(&mut integral_check);
1334
1335        // Reduced integral values always have a non-negative exponent.
1336        let mut b = if integral_check.exponent() >= 0 {
1337            // We believe our result should have been an integral
1338            integral_check
1339        } else {
1340            b
1341        };
1342
1343        numeric::munge_numeric(&mut b).unwrap();
1344        Ok(b)
1345    }
1346}
1347
1348#[sqlfunc(propagates_nulls = true)]
1349fn power(a: f64, b: f64) -> Result<f64, EvalError> {
1350    if a == 0.0 && b.is_sign_negative() {
1351        return Err(EvalError::Undefined(
1352            "zero raised to a negative power".into(),
1353        ));
1354    }
1355    if a.is_sign_negative() && b.fract() != 0.0 {
1356        // Equivalent to PG error:
1357        // > a negative number raised to a non-integer power yields a complex result
1358        return Err(EvalError::ComplexOutOfRange("pow".into()));
1359    }
1360    let res = a.powf(b);
1361    if res.is_infinite() {
1362        return Err(EvalError::FloatOverflow);
1363    }
1364    if res == 0.0 && a != 0.0 {
1365        return Err(EvalError::FloatUnderflow);
1366    }
1367    Ok(res)
1368}
1369
1370#[sqlfunc(propagates_nulls = true)]
1371fn uuid_generate_v5(a: uuid::Uuid, b: &str) -> uuid::Uuid {
1372    uuid::Uuid::new_v5(&a, b.as_bytes())
1373}
1374
1375#[sqlfunc(output_type = "Numeric", propagates_nulls = true)]
1376fn power_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1377    if a.is_zero() {
1378        if b.is_zero() {
1379            return Ok(Numeric::from(1));
1380        }
1381        if b.is_negative() {
1382            return Err(EvalError::Undefined(
1383                "zero raised to a negative power".into(),
1384            ));
1385        }
1386    }
1387    if a.is_negative() && b.exponent() < 0 {
1388        // Equivalent to PG error:
1389        // > a negative number raised to a non-integer power yields a complex result
1390        return Err(EvalError::ComplexOutOfRange("pow".into()));
1391    }
1392    let mut cx = numeric::cx_datum();
1393    cx.pow(&mut a, &b);
1394    let cx_status = cx.status();
1395    if cx_status.overflow() || (cx_status.invalid_operation() && !b.is_negative()) {
1396        Err(EvalError::FloatOverflow)
1397    } else if cx_status.subnormal() || cx_status.invalid_operation() {
1398        Err(EvalError::FloatUnderflow)
1399    } else {
1400        numeric::munge_numeric(&mut a).unwrap();
1401        Ok(a)
1402    }
1403}
1404
1405#[sqlfunc(propagates_nulls = true)]
1406fn get_bit(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1407    let err = EvalError::IndexOutOfRange {
1408        provided: index,
1409        valid_end: i32::try_from(bytes.len().saturating_mul(8)).unwrap_or(i32::MAX) - 1,
1410    };
1411
1412    let index = usize::try_from(index).map_err(|_| err.clone())?;
1413
1414    let byte_index = index / 8;
1415    let bit_index = index % 8;
1416
1417    let i = bytes
1418        .get(byte_index)
1419        .map(|b| (*b >> bit_index) & 1)
1420        .ok_or(err)?;
1421    assert!(i == 0 || i == 1);
1422    Ok(i32::from(i))
1423}
1424
1425#[sqlfunc(propagates_nulls = true)]
1426fn get_byte(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1427    let err = EvalError::IndexOutOfRange {
1428        provided: index,
1429        valid_end: i32::try_from(bytes.len()).unwrap_or(i32::MAX) - 1,
1430    };
1431    let i: &u8 = bytes
1432        .get(usize::try_from(index).map_err(|_| err.clone())?)
1433        .ok_or(err)?;
1434    Ok(i32::from(*i))
1435}
1436
1437#[sqlfunc(sqlname = "constant_time_compare_bytes", propagates_nulls = true)]
1438pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
1439    verify_slices_are_equal(a, b).is_ok()
1440}
1441
1442#[sqlfunc(sqlname = "constant_time_compare_strings", propagates_nulls = true)]
1443pub fn constant_time_eq_string(a: &str, b: &str) -> bool {
1444    verify_slices_are_equal(a.as_bytes(), b.as_bytes()).is_ok()
1445}
1446
1447#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1448fn range_contains_i32<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1449    a.contains_elem(&b)
1450}
1451
1452#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1453fn range_contains_i64<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1454    a.contains_elem(&elem)
1455}
1456
1457#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1458fn range_contains_date<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1459    a.contains_elem(&elem)
1460}
1461
1462#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1463fn range_contains_numeric<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1464    a.contains_elem(&elem)
1465}
1466
1467#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1468fn range_contains_timestamp<'a>(
1469    a: Range<Datum<'a>>,
1470    elem: CheckedTimestamp<NaiveDateTime>,
1471) -> bool {
1472    a.contains_elem(&elem)
1473}
1474
1475#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1476fn range_contains_timestamp_tz<'a>(
1477    a: Range<Datum<'a>>,
1478    elem: CheckedTimestamp<DateTime<Utc>>,
1479) -> bool {
1480    a.contains_elem(&elem)
1481}
1482
1483#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1484fn range_contains_i32_rev<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1485    a.contains_elem(&b)
1486}
1487
1488#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1489fn range_contains_i64_rev<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1490    a.contains_elem(&elem)
1491}
1492
1493#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1494fn range_contains_date_rev<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1495    a.contains_elem(&elem)
1496}
1497
1498#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1499fn range_contains_numeric_rev<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1500    a.contains_elem(&elem)
1501}
1502
1503#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1504fn range_contains_timestamp_rev<'a>(
1505    a: Range<Datum<'a>>,
1506    elem: CheckedTimestamp<NaiveDateTime>,
1507) -> bool {
1508    a.contains_elem(&elem)
1509}
1510
1511#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1512fn range_contains_timestamp_tz_rev<'a>(
1513    a: Range<Datum<'a>>,
1514    elem: CheckedTimestamp<DateTime<Utc>>,
1515) -> bool {
1516    a.contains_elem(&elem)
1517}
1518
1519/// Macro to define binary function for various range operations.
1520/// Parameters:
1521/// 1. Unique binary function symbol.
1522/// 2. Range function symbol.
1523/// 3. SQL name for the function.
1524macro_rules! range_fn {
1525    ($fn:expr, $range_fn:expr, $sqlname:expr) => {
1526        paste::paste! {
1527
1528            #[sqlfunc(
1529                output_type = "bool",
1530                is_infix_op = true,
1531                sqlname = $sqlname,
1532                propagates_nulls = true
1533            )]
1534            fn [< range_ $fn >]<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a>
1535            {
1536                if a.is_null() || b.is_null() { return Datum::Null }
1537                let l = a.unwrap_range();
1538                let r = b.unwrap_range();
1539                Datum::from(Range::<Datum<'a>>::$range_fn(&l, &r))
1540            }
1541        }
1542    };
1543}
1544
1545// RangeContainsRange is either @> or <@ depending on the order of the arguments.
1546// It doesn't influence the result, but it does influence the display string.
1547range_fn!(contains_range, contains_range, "@>");
1548range_fn!(contains_range_rev, contains_range, "<@");
1549range_fn!(overlaps, overlaps, "&&");
1550range_fn!(after, after, ">>");
1551range_fn!(before, before, "<<");
1552range_fn!(overleft, overleft, "&<");
1553range_fn!(overright, overright, "&>");
1554range_fn!(adjacent, adjacent, "-|-");
1555
1556#[sqlfunc(is_infix_op = true, sqlname = "+")]
1557fn range_union<T: Copy + Ord>(l: Range<T>, r: Range<T>) -> Result<Range<T>, EvalError> {
1558    Ok(l.union(&r)?)
1559}
1560
1561#[sqlfunc(is_infix_op = true, sqlname = "*")]
1562fn range_intersection<T: Copy + Ord>(l: Range<T>, r: Range<T>) -> Range<T> {
1563    l.intersection(&r)
1564}
1565
1566#[sqlfunc(
1567    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
1568    is_infix_op = true,
1569    sqlname = "-",
1570    propagates_nulls = true,
1571    introduces_nulls = false
1572)]
1573fn range_difference<'a>(
1574    l: Range<Datum<'a>>,
1575    r: Range<Datum<'a>>,
1576) -> Result<Range<Datum<'a>>, EvalError> {
1577    Ok(l.difference(&r)?)
1578}
1579
1580#[sqlfunc(is_infix_op = true, sqlname = "=", negate = "Some(NotEq.into())")]
1581fn eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1582    // SQL equality demands that if either input is null, then the result should be null. However,
1583    // we don't need to handle this case here; it is handled when `BinaryFunc::eval` checks
1584    // `propagates_nulls`.
1585    a == b
1586}
1587
1588#[sqlfunc(is_infix_op = true, sqlname = "!=", negate = "Some(Eq.into())")]
1589fn not_eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1590    a != b
1591}
1592
1593#[sqlfunc(
1594    is_monotone = "(true, true)",
1595    is_infix_op = true,
1596    sqlname = "<",
1597    negate = "Some(Gte.into())"
1598)]
1599fn lt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1600    a < b
1601}
1602
1603#[sqlfunc(
1604    is_monotone = "(true, true)",
1605    is_infix_op = true,
1606    sqlname = "<=",
1607    negate = "Some(Gt.into())"
1608)]
1609fn lte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1610    a <= b
1611}
1612
1613#[sqlfunc(
1614    is_monotone = "(true, true)",
1615    is_infix_op = true,
1616    sqlname = ">",
1617    negate = "Some(Lte.into())"
1618)]
1619fn gt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1620    a > b
1621}
1622
1623#[sqlfunc(
1624    is_monotone = "(true, true)",
1625    is_infix_op = true,
1626    sqlname = ">=",
1627    negate = "Some(Lt.into())"
1628)]
1629fn gte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1630    a >= b
1631}
1632
1633#[sqlfunc(sqlname = "tocharts", propagates_nulls = true)]
1634fn to_char_timestamp_format(ts: CheckedTimestamp<chrono::NaiveDateTime>, format: &str) -> String {
1635    let fmt = DateTimeFormat::compile(format);
1636    fmt.render(&*ts)
1637}
1638
1639#[sqlfunc(sqlname = "tochartstz", propagates_nulls = true)]
1640fn to_char_timestamp_tz_format(
1641    ts: CheckedTimestamp<chrono::DateTime<Utc>>,
1642    format: &str,
1643) -> String {
1644    let fmt = DateTimeFormat::compile(format);
1645    fmt.render(&*ts)
1646}
1647
1648#[sqlfunc(sqlname = "->", is_infix_op = true)]
1649fn jsonb_get_int64<'a>(a: JsonbRef<'a>, i: i64) -> Option<JsonbRef<'a>> {
1650    match a.into_datum() {
1651        Datum::List(list) => {
1652            let i = if i >= 0 {
1653                usize::cast_from(i.unsigned_abs())
1654            } else {
1655                // index backwards from the end
1656                let i = usize::cast_from(i.unsigned_abs());
1657                (list.iter().count()).wrapping_sub(i)
1658            };
1659            let v = list.iter().nth(i)?;
1660            // `v` should be valid jsonb because it came from a jsonb list, but we don't
1661            // panic on mismatch to avoid bringing down the whole system on corrupt data.
1662            // Instead, we'll return None.
1663            JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1664        }
1665        Datum::Map(_) => None,
1666        _ => {
1667            // I have no idea why postgres does this, but we're stuck with it
1668            (i == 0 || i == -1).then_some(a)
1669        }
1670    }
1671}
1672
1673#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1674fn jsonb_get_int64_stringify<'a>(
1675    a: JsonbRef<'a>,
1676    i: i64,
1677    temp_storage: &'a RowArena,
1678) -> Option<&'a str> {
1679    let json = jsonb_get_int64(a, i)?;
1680    jsonb_stringify(json.into_datum(), temp_storage)
1681}
1682
1683#[sqlfunc(sqlname = "->", is_infix_op = true)]
1684fn jsonb_get_string<'a>(a: JsonbRef<'a>, k: &str) -> Option<JsonbRef<'a>> {
1685    let dict = DatumMap::try_from_result(Ok::<_, ()>(a.into_datum())).ok()?;
1686    let v = dict.iter().find(|(k2, _v)| k == *k2).map(|(_k, v)| v)?;
1687    JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1688}
1689
1690#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1691fn jsonb_get_string_stringify<'a>(
1692    a: JsonbRef<'a>,
1693    k: &str,
1694    temp_storage: &'a RowArena,
1695) -> Option<&'a str> {
1696    let v = jsonb_get_string(a, k)?;
1697    jsonb_stringify(v.into_datum(), temp_storage)
1698}
1699
1700#[sqlfunc(sqlname = "#>", is_infix_op = true)]
1701fn jsonb_get_path<'a>(mut json: JsonbRef<'a>, b: Array<'a>) -> Option<JsonbRef<'a>> {
1702    let path = b.elements();
1703    for key in path.iter() {
1704        let key = match key {
1705            Datum::String(s) => s,
1706            Datum::Null => return None,
1707            _ => unreachable!("keys in jsonb_get_path known to be strings"),
1708        };
1709        let v = match json.into_datum() {
1710            Datum::Map(map) => map.iter().find(|(k, _)| key == *k).map(|(_k, v)| v),
1711            Datum::List(list) => {
1712                let i = strconv::parse_int64(key).ok()?;
1713                let i = if i >= 0 {
1714                    usize::cast_from(i.unsigned_abs())
1715                } else {
1716                    // index backwards from the end
1717                    let i = usize::cast_from(i.unsigned_abs());
1718                    (list.iter().count()).wrapping_sub(i)
1719                };
1720                list.iter().nth(i)
1721            }
1722            _ => return None,
1723        }?;
1724        json = JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()?;
1725    }
1726    Some(json)
1727}
1728
1729#[sqlfunc(sqlname = "#>>", is_infix_op = true)]
1730fn jsonb_get_path_stringify<'a>(
1731    a: JsonbRef<'a>,
1732    b: Array<'a>,
1733    temp_storage: &'a RowArena,
1734) -> Option<&'a str> {
1735    let json = jsonb_get_path(a, b)?;
1736    jsonb_stringify(json.into_datum(), temp_storage)
1737}
1738
1739#[sqlfunc(is_infix_op = true, sqlname = "?")]
1740fn jsonb_contains_string<'a>(a: JsonbRef<'a>, k: &str) -> bool {
1741    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
1742    // When the left operand is SQL NULL (NULL::jsonb), JsonbRef::try_from_result rejects it,
1743    // so the binary evaluator never calls this function and returns NULL (see binary.rs).
1744    // So, this function only runs for non-null jsonb; a.into_datum() never sees Datum::Null.
1745    match a.into_datum() {
1746        Datum::List(list) => list.iter().any(|k2| Datum::from(k) == k2),
1747        Datum::Map(dict) => dict.iter().any(|(k2, _v)| k == k2),
1748        Datum::String(string) => string == k,
1749        _ => false,
1750    }
1751}
1752
1753#[sqlfunc(is_infix_op = true, sqlname = "?", propagates_nulls = true)]
1754// Map keys are always text.
1755fn map_contains_key<'a>(map: DatumMap<'a>, k: &str) -> bool {
1756    map.iter().any(|(k2, _v)| k == k2)
1757}
1758
1759#[sqlfunc(is_infix_op = true, sqlname = "?&")]
1760fn map_contains_all_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1761    keys.elements()
1762        .iter()
1763        .all(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1764}
1765
1766#[sqlfunc(is_infix_op = true, sqlname = "?|", propagates_nulls = true)]
1767fn map_contains_any_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1768    keys.elements()
1769        .iter()
1770        .any(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1771}
1772
1773#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1774fn map_contains_map<'a>(map_a: DatumMap<'a>, b: DatumMap<'a>) -> bool {
1775    b.iter().all(|(b_key, b_val)| {
1776        map_a
1777            .iter()
1778            .any(|(a_key, a_val)| (a_key == b_key) && (a_val == b_val))
1779    })
1780}
1781
1782#[sqlfunc(is_infix_op = true, sqlname = "->", propagates_nulls = true)]
1783fn map_get_value<'a, T: FromDatum<'a>>(a: DatumMap<'a, T>, target_key: &str) -> Option<T> {
1784    a.typed_iter()
1785        .find(|(key, _v)| target_key == *key)
1786        .map(|(_k, v)| v)
1787}
1788
1789#[sqlfunc(is_infix_op = true, sqlname = "@>")]
1790fn list_contains_list<'a>(a: ExcludeNull<DatumList<'a>>, b: ExcludeNull<DatumList<'a>>) -> bool {
1791    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
1792    if b.iter().contains(&Datum::Null) {
1793        false
1794    } else {
1795        b.iter()
1796            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
1797    }
1798}
1799
1800#[sqlfunc(is_infix_op = true, sqlname = "<@")]
1801fn list_contains_list_rev<'a>(
1802    a: ExcludeNull<DatumList<'a>>,
1803    b: ExcludeNull<DatumList<'a>>,
1804) -> bool {
1805    list_contains_list(b, a)
1806}
1807
1808// TODO(jamii) nested loops are possibly not the fastest way to do this
1809#[sqlfunc(is_infix_op = true, sqlname = "@>")]
1810fn jsonb_contains_jsonb<'a>(a: JsonbRef<'a>, b: JsonbRef<'a>) -> bool {
1811    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
1812    fn contains(a: Datum, b: Datum, at_top_level: bool) -> bool {
1813        match (a, b) {
1814            (Datum::JsonNull, Datum::JsonNull) => true,
1815            (Datum::False, Datum::False) => true,
1816            (Datum::True, Datum::True) => true,
1817            (Datum::Numeric(a), Datum::Numeric(b)) => a == b,
1818            (Datum::String(a), Datum::String(b)) => a == b,
1819            (Datum::List(a), Datum::List(b)) => b
1820                .iter()
1821                .all(|b_elem| a.iter().any(|a_elem| contains(a_elem, b_elem, false))),
1822            (Datum::Map(a), Datum::Map(b)) => b.iter().all(|(b_key, b_val)| {
1823                a.iter()
1824                    .any(|(a_key, a_val)| (a_key == b_key) && contains(a_val, b_val, false))
1825            }),
1826
1827            // fun special case
1828            (Datum::List(a), b) => {
1829                at_top_level && a.iter().any(|a_elem| contains(a_elem, b, false))
1830            }
1831
1832            _ => false,
1833        }
1834    }
1835    contains(a.into_datum(), b.into_datum(), true)
1836}
1837
1838#[sqlfunc(is_infix_op = true, sqlname = "||")]
1839fn jsonb_concat<'a>(
1840    a: JsonbRef<'a>,
1841    b: JsonbRef<'a>,
1842    temp_storage: &'a RowArena,
1843) -> Option<JsonbRef<'a>> {
1844    let res = match (a.into_datum(), b.into_datum()) {
1845        (Datum::Map(dict_a), Datum::Map(dict_b)) => {
1846            let mut pairs = dict_b.iter().chain(dict_a.iter()).collect::<Vec<_>>();
1847            // stable sort, so if keys collide dedup prefers dict_b
1848            pairs.sort_by(|(k1, _v1), (k2, _v2)| k1.cmp(k2));
1849            pairs.dedup_by(|(k1, _v1), (k2, _v2)| k1 == k2);
1850            temp_storage.make_datum(|packer| packer.push_dict(pairs))
1851        }
1852        (Datum::List(list_a), Datum::List(list_b)) => {
1853            let elems = list_a.iter().chain(list_b.iter());
1854            temp_storage.make_datum(|packer| packer.push_list(elems))
1855        }
1856        (Datum::List(list_a), b) => {
1857            let elems = list_a.iter().chain(Some(b));
1858            temp_storage.make_datum(|packer| packer.push_list(elems))
1859        }
1860        (a, Datum::List(list_b)) => {
1861            let elems = Some(a).into_iter().chain(list_b.iter());
1862            temp_storage.make_datum(|packer| packer.push_list(elems))
1863        }
1864        _ => return None,
1865    };
1866    Some(JsonbRef::from_datum(res))
1867}
1868
1869#[sqlfunc(
1870    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1871    is_infix_op = true,
1872    sqlname = "-",
1873    propagates_nulls = true,
1874    introduces_nulls = true
1875)]
1876fn jsonb_delete_int64<'a>(a: Datum<'a>, i: i64, temp_storage: &'a RowArena) -> Datum<'a> {
1877    match a {
1878        Datum::List(list) => {
1879            let i = if i >= 0 {
1880                usize::cast_from(i.unsigned_abs())
1881            } else {
1882                // index backwards from the end
1883                let i = usize::cast_from(i.unsigned_abs());
1884                (list.iter().count()).wrapping_sub(i)
1885            };
1886            let elems = list
1887                .iter()
1888                .enumerate()
1889                .filter(|(i2, _e)| i != *i2)
1890                .map(|(_, e)| e);
1891            temp_storage.make_datum(|packer| packer.push_list(elems))
1892        }
1893        _ => Datum::Null,
1894    }
1895}
1896
1897#[sqlfunc(
1898    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1899    is_infix_op = true,
1900    sqlname = "-",
1901    propagates_nulls = true,
1902    introduces_nulls = true
1903)]
1904fn jsonb_delete_string<'a>(a: Datum<'a>, k: &str, temp_storage: &'a RowArena) -> Datum<'a> {
1905    match a {
1906        Datum::List(list) => {
1907            let elems = list.iter().filter(|e| Datum::from(k) != *e);
1908            temp_storage.make_datum(|packer| packer.push_list(elems))
1909        }
1910        Datum::Map(dict) => {
1911            let pairs = dict.iter().filter(|(k2, _v)| k != *k2);
1912            temp_storage.make_datum(|packer| packer.push_dict(pairs))
1913        }
1914        _ => Datum::Null,
1915    }
1916}
1917
1918#[sqlfunc(
1919    sqlname = "extractiv",
1920    propagates_nulls = true,
1921    introduces_nulls = false
1922)]
1923fn date_part_interval_numeric(units: &str, b: Interval) -> Result<Numeric, EvalError> {
1924    match units.parse() {
1925        Ok(units) => Ok(date_part_interval_inner::<Numeric>(units, b)?),
1926        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1927    }
1928}
1929
1930#[sqlfunc(
1931    sqlname = "date_partiv",
1932    propagates_nulls = true,
1933    introduces_nulls = false
1934)]
1935fn date_part_interval_f64(units: &str, b: Interval) -> Result<f64, EvalError> {
1936    match units.parse() {
1937        Ok(units) => Ok(date_part_interval_inner::<f64>(units, b)?),
1938        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1939    }
1940}
1941
1942#[sqlfunc(
1943    sqlname = "extractt",
1944    propagates_nulls = true,
1945    introduces_nulls = false
1946)]
1947fn date_part_time_numeric(units: &str, b: chrono::NaiveTime) -> Result<Numeric, EvalError> {
1948    match units.parse() {
1949        Ok(units) => Ok(date_part_time_inner::<Numeric>(units, b)?),
1950        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1951    }
1952}
1953
1954#[sqlfunc(
1955    sqlname = "date_partt",
1956    propagates_nulls = true,
1957    introduces_nulls = false
1958)]
1959fn date_part_time_f64(units: &str, b: chrono::NaiveTime) -> Result<f64, EvalError> {
1960    match units.parse() {
1961        Ok(units) => Ok(date_part_time_inner::<f64>(units, b)?),
1962        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1963    }
1964}
1965
1966#[sqlfunc(sqlname = "extractts", propagates_nulls = true)]
1967fn date_part_timestamp_timestamp_numeric(
1968    units: &str,
1969    ts: CheckedTimestamp<NaiveDateTime>,
1970) -> Result<Numeric, EvalError> {
1971    match units.parse() {
1972        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1973        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1974    }
1975}
1976
1977#[sqlfunc(sqlname = "extracttstz", propagates_nulls = true)]
1978fn date_part_timestamp_timestamp_tz_numeric(
1979    units: &str,
1980    ts: CheckedTimestamp<DateTime<Utc>>,
1981) -> Result<Numeric, EvalError> {
1982    match units.parse() {
1983        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1984        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1985    }
1986}
1987
1988#[sqlfunc(sqlname = "date_partts", propagates_nulls = true)]
1989fn date_part_timestamp_timestamp_f64(
1990    units: &str,
1991    ts: CheckedTimestamp<NaiveDateTime>,
1992) -> Result<f64, EvalError> {
1993    match units.parse() {
1994        Ok(units) => date_part_timestamp_inner(units, &*ts),
1995        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1996    }
1997}
1998
1999#[sqlfunc(sqlname = "date_parttstz", propagates_nulls = true)]
2000fn date_part_timestamp_timestamp_tz_f64(
2001    units: &str,
2002    ts: CheckedTimestamp<DateTime<Utc>>,
2003) -> Result<f64, EvalError> {
2004    match units.parse() {
2005        Ok(units) => date_part_timestamp_inner(units, &*ts),
2006        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2007    }
2008}
2009
2010#[sqlfunc(sqlname = "extractd", propagates_nulls = true)]
2011fn extract_date_units(units: &str, b: Date) -> Result<Numeric, EvalError> {
2012    match units.parse() {
2013        Ok(units) => Ok(extract_date_inner(units, b.into())?),
2014        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2015    }
2016}
2017
2018pub fn date_bin<T>(
2019    stride: Interval,
2020    source: CheckedTimestamp<T>,
2021    origin: CheckedTimestamp<T>,
2022) -> Result<CheckedTimestamp<T>, EvalError>
2023where
2024    T: TimestampLike,
2025{
2026    if stride.months != 0 {
2027        return Err(EvalError::DateBinOutOfRange(
2028            "timestamps cannot be binned into intervals containing months or years".into(),
2029        ));
2030    }
2031
2032    let stride_ns = match stride.duration_as_chrono().num_nanoseconds() {
2033        Some(ns) if ns <= 0 => Err(EvalError::DateBinOutOfRange(
2034            "stride must be greater than zero".into(),
2035        )),
2036        Some(ns) => Ok(ns),
2037        None => Err(EvalError::DateBinOutOfRange(
2038            format!("stride cannot exceed {}/{} nanoseconds", i64::MAX, i64::MIN,).into(),
2039        )),
2040    }?;
2041
2042    // Make sure the returned timestamp is at the start of the bin, even if the
2043    // origin is in the future. We do this here because `T` is not `Copy` and
2044    // gets moved by its subtraction operation.
2045    let sub_stride = origin > source;
2046
2047    let tm_diff = (source - origin.clone()).num_nanoseconds().ok_or_else(|| {
2048        EvalError::DateBinOutOfRange(
2049            "source and origin must not differ more than 2^63 nanoseconds".into(),
2050        )
2051    })?;
2052
2053    let remainder = tm_diff % stride_ns;
2054    let mut tm_delta = tm_diff - remainder;
2055
2056    if sub_stride && remainder != 0 {
2057        tm_delta = tm_delta.checked_sub(stride_ns).ok_or_else(|| {
2058            EvalError::DateBinOutOfRange(
2059                "source and origin must not differ more than 2^63 nanoseconds".into(),
2060            )
2061        })?;
2062    }
2063
2064    let res = origin
2065        .checked_add_signed(Duration::nanoseconds(tm_delta))
2066        .ok_or(EvalError::TimestampOutOfRange)?;
2067    Ok(CheckedTimestamp::from_timestamplike(res)?)
2068}
2069
2070// Non-monotone in `stride`: the result is `origin + floor((source - origin) /
2071// stride) * stride`. For a fixed source like `2024-01-01 12:00:00`, a 1-day
2072// stride bins to `2024-01-01 00:00:00`, but a 2-day stride bins to
2073// `2023-12-31 00:00:00` — i.e. the lex-larger interval produces an earlier
2074// timestamp. Monotone in `source`.
2075#[sqlfunc(is_monotone = "(false, true)", sqlname = "bin_unix_epoch_timestamp")]
2076fn date_bin_timestamp(
2077    stride: Interval,
2078    source: CheckedTimestamp<NaiveDateTime>,
2079) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2080    let origin =
2081        CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap().naive_utc())
2082            .expect("must fit");
2083    date_bin(stride, source, origin)
2084}
2085
2086// See `date_bin_timestamp` for why this is not monotone in `stride`.
2087#[sqlfunc(is_monotone = "(false, true)", sqlname = "bin_unix_epoch_timestamptz")]
2088fn date_bin_timestamp_tz(
2089    stride: Interval,
2090    source: CheckedTimestamp<DateTime<Utc>>,
2091) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2092    let origin = CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap())
2093        .expect("must fit");
2094    date_bin(stride, source, origin)
2095}
2096
2097#[sqlfunc(sqlname = "date_truncts", propagates_nulls = true)]
2098fn date_trunc_units_timestamp(
2099    units: &str,
2100    ts: CheckedTimestamp<NaiveDateTime>,
2101) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2102    match units.parse() {
2103        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2104        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2105    }
2106}
2107
2108#[sqlfunc(sqlname = "date_trunctstz", propagates_nulls = true)]
2109fn date_trunc_units_timestamp_tz(
2110    units: &str,
2111    ts: CheckedTimestamp<DateTime<Utc>>,
2112) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2113    match units.parse() {
2114        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2115        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2116    }
2117}
2118
2119#[sqlfunc(sqlname = "date_trunciv", propagates_nulls = true)]
2120fn date_trunc_interval(units: &str, mut interval: Interval) -> Result<Interval, EvalError> {
2121    let dtf = units
2122        .parse()
2123        .map_err(|_| EvalError::UnknownUnits(units.into()))?;
2124
2125    interval
2126        .truncate_low_fields(dtf, Some(0), RoundBehavior::Truncate)
2127        .expect(
2128            "truncate_low_fields should not fail with max_precision 0 and RoundBehavior::Truncate",
2129        );
2130    Ok(interval)
2131}
2132
2133/// Parses a named timezone like `EST` or `America/New_York`, or a fixed-offset timezone like `-05:00`.
2134///
2135/// The interpretation of fixed offsets depend on whether the POSIX or ISO 8601 standard is being
2136/// used.
2137pub(crate) fn parse_timezone(tz: &str, spec: TimezoneSpec) -> Result<Timezone, EvalError> {
2138    Timezone::parse(tz, spec).map_err(|_| EvalError::InvalidTimezone(tz.into()))
2139}
2140
2141/// Converts the time datum `b`, which is assumed to be in UTC, to the timezone that the interval datum `a` is assumed
2142/// to represent. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2143/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2144#[sqlfunc(sqlname = "timezoneit")]
2145fn timezone_interval_time_binary(
2146    interval: Interval,
2147    time: chrono::NaiveTime,
2148) -> Result<chrono::NaiveTime, EvalError> {
2149    if interval.months != 0 {
2150        Err(EvalError::InvalidTimezoneInterval)
2151    } else {
2152        Ok(time.overflowing_add_signed(interval.duration_as_chrono()).0)
2153    }
2154}
2155
2156/// Converts the timestamp datum `b`, which is assumed to be in the time of the timezone datum `a` to a timestamptz
2157/// in UTC. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2158/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2159#[sqlfunc(sqlname = "timezoneits")]
2160fn timezone_interval_timestamp_binary(
2161    interval: Interval,
2162    ts: CheckedTimestamp<NaiveDateTime>,
2163) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2164    if interval.months != 0 {
2165        Err(EvalError::InvalidTimezoneInterval)
2166    } else {
2167        match ts.checked_sub_signed(interval.duration_as_chrono()) {
2168            Some(sub) => Ok(DateTime::from_naive_utc_and_offset(sub, Utc).try_into()?),
2169            None => Err(EvalError::TimestampOutOfRange),
2170        }
2171    }
2172}
2173
2174/// Converts the UTC timestamptz datum `b`, to the local timestamp of the timezone datum `a`.
2175/// The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2176/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2177#[sqlfunc(sqlname = "timezoneitstz")]
2178fn timezone_interval_timestamp_tz_binary(
2179    interval: Interval,
2180    tstz: CheckedTimestamp<DateTime<Utc>>,
2181) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2182    if interval.months != 0 {
2183        return Err(EvalError::InvalidTimezoneInterval);
2184    }
2185    match tstz
2186        .naive_utc()
2187        .checked_add_signed(interval.duration_as_chrono())
2188    {
2189        Some(dt) => Ok(dt.try_into()?),
2190        None => Err(EvalError::TimestampOutOfRange),
2191    }
2192}
2193
2194#[sqlfunc(
2195    output_type_expr = r#"SqlScalarType::Record {
2196                fields: [
2197                    ("abbrev".into(), SqlScalarType::String.nullable(false)),
2198                    ("base_utc_offset".into(), SqlScalarType::Interval.nullable(false)),
2199                    ("dst_offset".into(), SqlScalarType::Interval.nullable(false)),
2200                ].into(),
2201                custom_id: None,
2202            }.nullable(true)"#,
2203    propagates_nulls = true,
2204    introduces_nulls = false
2205)]
2206fn timezone_offset<'a>(
2207    tz_str: &str,
2208    b: CheckedTimestamp<chrono::DateTime<Utc>>,
2209    temp_storage: &'a RowArena,
2210) -> Result<Datum<'a>, EvalError> {
2211    let tz = match Tz::from_str_insensitive(tz_str) {
2212        Ok(tz) => tz,
2213        Err(_) => return Err(EvalError::InvalidIanaTimezoneId(tz_str.into())),
2214    };
2215    let offset = tz.offset_from_utc_datetime(&b.naive_utc());
2216    Ok(temp_storage.make_datum(|packer| {
2217        packer.push_list_with(|packer| {
2218            packer.push(Datum::from(offset.abbreviation()));
2219            packer.push(Datum::from(offset.base_utc_offset()));
2220            packer.push(Datum::from(offset.dst_offset()));
2221        });
2222    }))
2223}
2224
2225/// Determines if an mz_aclitem contains one of the specified privileges. This will return true if
2226/// any of the listed privileges are contained in the mz_aclitem.
2227#[sqlfunc(
2228    sqlname = "mz_aclitem_contains_privilege",
2229    output_type = "bool",
2230    propagates_nulls = true
2231)]
2232fn mz_acl_item_contains_privilege(
2233    mz_acl_item: MzAclItem,
2234    privileges: &str,
2235) -> Result<bool, EvalError> {
2236    let acl_mode = AclMode::parse_multiple_privileges(privileges)
2237        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
2238    let contains = !mz_acl_item.acl_mode.intersection(acl_mode).is_empty();
2239    Ok(contains)
2240}
2241
2242#[sqlfunc]
2243// transliterated from postgres/src/backend/utils/adt/misc.c
2244fn parse_ident<'a>(ident: &'a str, strict: bool) -> Result<ArrayRustType<Cow<'a, str>>, EvalError> {
2245    fn is_ident_start(c: char) -> bool {
2246        matches!(c, 'A'..='Z' | 'a'..='z' | '_' | '\u{80}'..=char::MAX)
2247    }
2248
2249    fn is_ident_cont(c: char) -> bool {
2250        matches!(c, '0'..='9' | '$') || is_ident_start(c)
2251    }
2252
2253    let mut elems = vec![];
2254    let buf = &mut LexBuf::new(ident);
2255
2256    let mut after_dot = false;
2257
2258    buf.take_while(|ch| ch.is_ascii_whitespace());
2259
2260    loop {
2261        let mut missing_ident = true;
2262
2263        let c = buf.next();
2264
2265        if c == Some('"') {
2266            let s = buf.take_while(|ch| !matches!(ch, '"'));
2267
2268            if buf.next() != Some('"') {
2269                return Err(EvalError::InvalidIdentifier {
2270                    ident: ident.into(),
2271                    detail: Some("String has unclosed double quotes.".into()),
2272                });
2273            }
2274            elems.push(Cow::Borrowed(s));
2275            missing_ident = false;
2276        } else if c.map(is_ident_start).unwrap_or(false) {
2277            buf.prev();
2278            let s = buf.take_while(is_ident_cont);
2279            elems.push(Cow::Owned(s.to_ascii_lowercase()));
2280            missing_ident = false;
2281        }
2282
2283        if missing_ident {
2284            if c == Some('.') {
2285                return Err(EvalError::InvalidIdentifier {
2286                    ident: ident.into(),
2287                    detail: Some("No valid identifier before \".\".".into()),
2288                });
2289            } else if after_dot {
2290                return Err(EvalError::InvalidIdentifier {
2291                    ident: ident.into(),
2292                    detail: Some("No valid identifier after \".\".".into()),
2293                });
2294            } else {
2295                return Err(EvalError::InvalidIdentifier {
2296                    ident: ident.into(),
2297                    detail: None,
2298                });
2299            }
2300        }
2301
2302        buf.take_while(|ch| ch.is_ascii_whitespace());
2303
2304        match buf.next() {
2305            Some('.') => {
2306                after_dot = true;
2307
2308                buf.take_while(|ch| ch.is_ascii_whitespace());
2309            }
2310            Some(_) if strict => {
2311                return Err(EvalError::InvalidIdentifier {
2312                    ident: ident.into(),
2313                    detail: None,
2314                });
2315            }
2316            _ => break,
2317        }
2318    }
2319
2320    Ok(elems.into())
2321}
2322
2323fn regexp_split_to_array_re<'a>(
2324    text: &str,
2325    regexp: &Regex,
2326    temp_storage: &'a RowArena,
2327) -> Result<Datum<'a>, EvalError> {
2328    let found = mz_regexp::regexp_split_to_array(text, regexp);
2329    let mut row = Row::default();
2330    let mut packer = row.packer();
2331    packer.try_push_array(
2332        &[ArrayDimension {
2333            lower_bound: 1,
2334            length: found.len(),
2335        }],
2336        found.into_iter().map(Datum::String),
2337    )?;
2338    Ok(temp_storage.push_unary_row(row))
2339}
2340
2341#[sqlfunc(propagates_nulls = true)]
2342fn pretty_sql<'a>(sql: &str, width: i32, temp_storage: &'a RowArena) -> Result<&'a str, EvalError> {
2343    let width =
2344        usize::try_from(width).map_err(|_| EvalError::PrettyError("invalid width".into()))?;
2345    let pretty = pretty_str(
2346        sql,
2347        PrettyConfig {
2348            width,
2349            format_mode: FormatMode::Simple,
2350        },
2351    )
2352    .map_err(|e| EvalError::PrettyError(e.to_string().into()))?;
2353    let pretty = temp_storage.push_string(pretty);
2354    Ok(pretty)
2355}
2356
2357#[sqlfunc]
2358fn redact_sql(sql: &str) -> Result<String, EvalError> {
2359    let stmts = mz_sql_parser::parser::parse_statements(sql)
2360        .map_err(|e| EvalError::RedactError(e.to_string().into()))?;
2361    match stmts.len() {
2362        1 => Ok(stmts[0].ast.to_ast_string_redacted()),
2363        n => Err(EvalError::RedactError(
2364            format!("expected a single statement, found {n}").into(),
2365        )),
2366    }
2367}
2368
2369#[sqlfunc(propagates_nulls = true)]
2370fn starts_with(a: &str, b: &str) -> bool {
2371    a.starts_with(b)
2372}
2373
2374#[sqlfunc(
2375    sqlname = "||",
2376    is_infix_op = true,
2377    propagates_nulls = true,
2378    // Text concatenation is monotonic in its second argument, because if I change the
2379    // second argument but don't change the first argument, then we won't find a difference
2380    // in that part of the concatenation result that came from the first argument, so we'll
2381    // find the difference that comes from changing the second argument.
2382    // (It's not monotonic in its first argument, because e.g.,
2383    // 'A' < 'AA' but 'AZ' > 'AAZ'.)
2384    is_monotone = (false, true),
2385)]
2386fn text_concat_binary(a: &str, b: &str) -> Result<String, EvalError> {
2387    if a.len() + b.len() > MAX_STRING_FUNC_RESULT_BYTES {
2388        return Err(EvalError::LengthTooLarge);
2389    }
2390    let mut buf = String::with_capacity(a.len() + b.len());
2391    buf.push_str(a);
2392    buf.push_str(b);
2393    Ok(buf)
2394}
2395
2396#[sqlfunc(propagates_nulls = true, introduces_nulls = false)]
2397fn like_escape<'a>(
2398    pattern: &str,
2399    b: &str,
2400    temp_storage: &'a RowArena,
2401) -> Result<&'a str, EvalError> {
2402    let escape = like_pattern::EscapeBehavior::from_str(b)?;
2403    let normalized = like_pattern::normalize_pattern(pattern, escape)?;
2404    Ok(temp_storage.push_string(normalized))
2405}
2406
2407#[sqlfunc(is_infix_op = true, sqlname = "like")]
2408fn is_like_match_case_sensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2409    like_pattern::compile(pattern, false).map(|needle| needle.is_match(haystack))
2410}
2411
2412#[sqlfunc(is_infix_op = true, sqlname = "ilike")]
2413fn is_like_match_case_insensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2414    like_pattern::compile(pattern, true).map(|needle| needle.is_match(haystack))
2415}
2416
2417#[sqlfunc(is_infix_op = true, sqlname = "~")]
2418fn is_regexp_match_case_sensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2419    let regex = build_regex(needle, "")?;
2420    Ok(regex.is_match(haystack))
2421}
2422
2423#[sqlfunc(is_infix_op = true, sqlname = "~*")]
2424fn is_regexp_match_case_insensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2425    let regex = build_regex(needle, "i")?;
2426    Ok(regex.is_match(haystack))
2427}
2428
2429fn regexp_match_static<'a>(
2430    haystack: Datum<'a>,
2431    temp_storage: &'a RowArena,
2432    needle: &regex::Regex,
2433) -> Result<Datum<'a>, EvalError> {
2434    let mut row = Row::default();
2435    let mut packer = row.packer();
2436    if needle.captures_len() > 1 {
2437        // The regex contains capture groups, so return an array containing the
2438        // matched text in each capture group, unless the entire match fails.
2439        // Individual capture groups may also be null if that group did not
2440        // participate in the match.
2441        match needle.captures(haystack.unwrap_str()) {
2442            None => packer.push(Datum::Null),
2443            Some(captures) => packer.try_push_array(
2444                &[ArrayDimension {
2445                    lower_bound: 1,
2446                    length: captures.len() - 1,
2447                }],
2448                // Skip the 0th capture group, which is the whole match.
2449                captures.iter().skip(1).map(|mtch| match mtch {
2450                    None => Datum::Null,
2451                    Some(mtch) => Datum::String(mtch.as_str()),
2452                }),
2453            )?,
2454        }
2455    } else {
2456        // The regex contains no capture groups, so return a one-element array
2457        // containing the match, or null if there is no match.
2458        match needle.find(haystack.unwrap_str()) {
2459            None => packer.push(Datum::Null),
2460            Some(mtch) => packer.try_push_array(
2461                &[ArrayDimension {
2462                    lower_bound: 1,
2463                    length: 1,
2464                }],
2465                iter::once(Datum::String(mtch.as_str())),
2466            )?,
2467        };
2468    };
2469    Ok(temp_storage.push_unary_row(row))
2470}
2471
/// Splits the `g` (global) flag out of `flags` for use with `Regex::replacen`.
///
/// Returns the replacement `limit` — `0` when `g` is present (replace all),
/// `1` otherwise (replace only the first match) — together with `flags`
/// stripped of every `g`. The flags are borrowed unchanged when no `g` is
/// present, so the common path does not allocate.
pub(crate) fn regexp_replace_parse_flags(flags: &str) -> (usize, Cow<'_, str>) {
    if !flags.contains('g') {
        return (1, Cow::Borrowed(flags));
    }
    let remaining: String = flags.chars().filter(|&flag| flag != 'g').collect();
    (0, Cow::Owned(remaining))
}
2485
2486pub fn build_regex(needle: &str, flags: &str) -> Result<Regex, EvalError> {
2487    let mut case_insensitive = false;
2488    // Note: Postgres accepts it when both flags are present, taking the last one. We do the same.
2489    for f in flags.chars() {
2490        match f {
2491            'i' => {
2492                case_insensitive = true;
2493            }
2494            'c' => {
2495                case_insensitive = false;
2496            }
2497            _ => return Err(EvalError::InvalidRegexFlag(f)),
2498        }
2499    }
2500    Ok(Regex::new(needle, case_insensitive)?)
2501}
2502
2503#[sqlfunc(sqlname = "repeat")]
2504fn repeat_string(string: &str, count: i32) -> Result<String, EvalError> {
2505    let len = usize::try_from(count).unwrap_or(0);
2506    if (len * string.len()) > MAX_STRING_FUNC_RESULT_BYTES {
2507        return Err(EvalError::LengthTooLarge);
2508    }
2509    Ok(string.repeat(len))
2510}
2511
2512/// Constructs a new zero or one dimensional array out of an arbitrary number of
2513/// scalars.
2514///
2515/// If `datums` is empty, constructs a zero-dimensional array. Otherwise,
2516/// constructs a one dimensional array whose lower bound is one and whose length
2517/// is equal to `datums.len()`.
2518fn array_create_scalar<'a>(
2519    datums: &[Datum<'a>],
2520    temp_storage: &'a RowArena,
2521) -> Result<Datum<'a>, EvalError> {
2522    let mut dims = &[ArrayDimension {
2523        lower_bound: 1,
2524        length: datums.len(),
2525    }][..];
2526    if datums.is_empty() {
2527        // Per PostgreSQL, empty arrays are represented with zero dimensions,
2528        // not one dimension of zero length. We write this condition a little
2529        // strangely to satisfy the borrow checker while avoiding an allocation.
2530        dims = &[];
2531    }
2532    let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(dims, datums))?;
2533    Ok(datum)
2534}
2535
/// Writes the text form of datum `d`, interpreted as SQL type `ty`, into
/// `buf`, and returns the `strconv::Nestable` reported by the formatter.
///
/// The match below dispatches on `ty` and calls the `unwrap_*` accessor that
/// corresponds to that type, so `d` must actually hold a value of the kind
/// `ty` describes (NOTE(review): the accessors presumably panic on a
/// mismatch — confirm). Container types (records, arrays, lists, maps,
/// ranges) recurse into this function for each element.
///
/// The only fallible arm visible here is `PgLegacyChar`, whose error from
/// `format_pg_legacy_char` is propagated; container arms propagate whatever
/// their element callbacks return.
fn stringify_datum<'a, B>(
    buf: &mut B,
    d: Datum<'a>,
    ty: &SqlScalarType,
) -> Result<strconv::Nestable, EvalError>
where
    B: FormatBuffer,
{
    use SqlScalarType::*;
    match &ty {
        AclItem => Ok(strconv::format_acl_item(buf, d.unwrap_acl_item())),
        Bool => Ok(strconv::format_bool(buf, d.unwrap_bool())),
        Int16 => Ok(strconv::format_int16(buf, d.unwrap_int16())),
        Int32 => Ok(strconv::format_int32(buf, d.unwrap_int32())),
        Int64 => Ok(strconv::format_int64(buf, d.unwrap_int64())),
        UInt16 => Ok(strconv::format_uint16(buf, d.unwrap_uint16())),
        // The OID-like types are all formatted as plain `u32` values.
        UInt32 | Oid | RegClass | RegProc | RegType => {
            Ok(strconv::format_uint32(buf, d.unwrap_uint32()))
        }
        UInt64 => Ok(strconv::format_uint64(buf, d.unwrap_uint64())),
        Float32 => Ok(strconv::format_float32(buf, d.unwrap_float32())),
        Float64 => Ok(strconv::format_float64(buf, d.unwrap_float64())),
        Numeric { .. } => Ok(strconv::format_numeric(buf, &d.unwrap_numeric())),
        Date => Ok(strconv::format_date(buf, d.unwrap_date())),
        Time => Ok(strconv::format_time(buf, d.unwrap_time())),
        Timestamp { .. } => Ok(strconv::format_timestamp(buf, &d.unwrap_timestamp())),
        TimestampTz { .. } => Ok(strconv::format_timestamptz(buf, &d.unwrap_timestamptz())),
        Interval => Ok(strconv::format_interval(buf, d.unwrap_interval())),
        Bytes => Ok(strconv::format_bytes(buf, d.unwrap_bytes())),
        String | VarChar { .. } | PgLegacyName => Ok(strconv::format_string(buf, d.unwrap_str())),
        // `Char` values are passed through `format_str_pad` with the declared
        // length before being written.
        Char { length } => Ok(strconv::format_string(
            buf,
            &mz_repr::adt::char::format_str_pad(d.unwrap_str(), *length),
        )),
        PgLegacyChar => {
            format_pg_legacy_char(buf, d.unwrap_uint8())?;
            Ok(strconv::Nestable::MayNeedEscaping)
        }
        Jsonb => Ok(strconv::format_jsonb(buf, JsonbRef::from_datum(d))),
        Uuid => Ok(strconv::format_uuid(buf, d.unwrap_uuid())),
        Record { fields, .. } => {
            // Each callback invocation consumes the next declared field type.
            // The `unwrap` panics if the callback runs more times than there
            // are declared fields (assumes `format_record` calls it once per
            // record element — TODO confirm).
            let mut fields = fields.iter();
            strconv::format_record(buf, d.unwrap_list(), |buf, d| {
                let (_name, ty) = fields.next().unwrap();
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, &ty.scalar_type)
                }
            })
        }
        Array(elem_type) => strconv::format_array(
            buf,
            &d.unwrap_array().dims().into_iter().collect::<Vec<_>>(),
            d.unwrap_array().elements(),
            |buf, d| {
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, elem_type)
                }
            },
        ),
        List { element_type, .. } => strconv::format_list(buf, d.unwrap_list(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, element_type)
            }
        }),
        Map { value_type, .. } => strconv::format_map(buf, &d.unwrap_map(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, value_type)
            }
        }),
        // Unlike the other containers, elements here are formatted as `Int16`
        // without a null check.
        Int2Vector => strconv::format_legacy_vector(buf, d.unwrap_array().elements(), |buf, d| {
            stringify_datum(buf.nonnull_buffer(), d, &SqlScalarType::Int16)
        }),
        MzTimestamp { .. } => Ok(strconv::format_mz_timestamp(buf, d.unwrap_mz_timestamp())),
        // The callback receives an `Option`; `None` is written as null.
        Range { element_type } => strconv::format_range(buf, &d.unwrap_range(), |buf, d| match d {
            Some(d) => stringify_datum(buf.nonnull_buffer(), *d, element_type),
            None => Ok::<_, EvalError>(buf.write_null()),
        }),
        MzAclItem => Ok(strconv::format_mz_acl_item(buf, d.unwrap_mz_acl_item())),
    }
}
2624
2625#[sqlfunc]
2626fn position(substring: &str, string: &str) -> Result<i32, EvalError> {
2627    let char_index = string.find(substring);
2628
2629    if let Some(char_index) = char_index {
2630        // find the index in char space
2631        let string_prefix = &string[0..char_index];
2632
2633        let num_prefix_chars = string_prefix.chars().count();
2634        let num_prefix_chars = i32::try_from(num_prefix_chars)
2635            .map_err(|_| EvalError::Int32OutOfRange(num_prefix_chars.to_string().into()))?;
2636
2637        Ok(num_prefix_chars + 1)
2638    } else {
2639        Ok(0)
2640    }
2641}
2642
2643#[sqlfunc]
2644fn strpos(string: &str, substring: &str) -> Result<i32, EvalError> {
2645    position(substring, string)
2646}
2647
2648#[sqlfunc(
2649    propagates_nulls = true,
2650    // `left` is unfortunately not monotonic (at least for negative second arguments),
2651    // because 'aa' < 'z', but `left(_, -1)` makes 'a' > ''.
2652    is_monotone = (false, false)
2653)]
2654fn left<'a>(string: &'a str, b: i32) -> Result<&'a str, EvalError> {
2655    let n = i64::from(b);
2656
2657    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2658
2659    let end_in_bytes = match n.cmp(&0) {
2660        Ordering::Equal => 0,
2661        Ordering::Greater => {
2662            let n = usize::try_from(n).map_err(|_| {
2663                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2664            })?;
2665            // nth from the back
2666            byte_indices.nth(n).unwrap_or(string.len())
2667        }
2668        Ordering::Less => {
2669            let n = usize::try_from(n.abs() - 1).map_err(|_| {
2670                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2671            })?;
2672            byte_indices.rev().nth(n).unwrap_or(0)
2673        }
2674    };
2675
2676    Ok(&string[..end_in_bytes])
2677}
2678
2679#[sqlfunc(propagates_nulls = true)]
2680fn right<'a>(string: &'a str, n: i32) -> Result<&'a str, EvalError> {
2681    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2682
2683    let start_in_bytes = if n == 0 {
2684        string.len()
2685    } else if n > 0 {
2686        let n = usize::try_from(n - 1).map_err(|_| {
2687            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2688        })?;
2689        // nth from the back
2690        byte_indices.rev().nth(n).unwrap_or(0)
2691    } else if n == i32::MIN {
2692        // this seems strange but Postgres behaves like this
2693        0
2694    } else {
2695        let n = n.abs();
2696        let n = usize::try_from(n).map_err(|_| {
2697            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2698        })?;
2699        byte_indices.nth(n).unwrap_or(string.len())
2700    };
2701
2702    Ok(&string[start_in_bytes..])
2703}
2704
2705#[sqlfunc(sqlname = "btrim", propagates_nulls = true)]
2706fn trim<'a>(a: &'a str, trim_chars: &str) -> &'a str {
2707    a.trim_matches(|c| trim_chars.contains(c))
2708}
2709
2710#[sqlfunc(sqlname = "ltrim", propagates_nulls = true)]
2711fn trim_leading<'a>(a: &'a str, trim_chars: &str) -> &'a str {
2712    a.trim_start_matches(|c| trim_chars.contains(c))
2713}
2714
2715#[sqlfunc(sqlname = "rtrim", propagates_nulls = true)]
2716fn trim_trailing<'a>(a: &'a str, trim_chars: &str) -> &'a str {
2717    a.trim_end_matches(|c| trim_chars.contains(c))
2718}
2719
2720#[sqlfunc(
2721    sqlname = "array_length",
2722    propagates_nulls = true,
2723    introduces_nulls = true
2724)]
2725fn array_length<'a>(a: Array<'a>, b: i64) -> Result<Option<i32>, EvalError> {
2726    let i = match usize::try_from(b) {
2727        Ok(0) | Err(_) => return Ok(None),
2728        Ok(n) => n - 1,
2729    };
2730    Ok(match a.dims().into_iter().nth(i) {
2731        None => None,
2732        Some(dim) => Some(
2733            dim.length
2734                .try_into()
2735                .map_err(|_| EvalError::Int32OutOfRange(dim.length.to_string().into()))?,
2736        ),
2737    })
2738}
2739
2740#[sqlfunc(is_infix_op = true)]
2741// TODO(benesch): remove potentially dangerous usage of `as`.
2742#[allow(clippy::as_conversions)]
2743fn array_lower<'a>(a: Array<'a>, i: i64) -> Result<Option<i32>, EvalError> {
2744    if i < 1 {
2745        return Ok(None);
2746    }
2747    a.dims()
2748        .into_iter()
2749        .nth(i as usize - 1)
2750        .map(|dim| {
2751            let (lower, _upper) = dim.dimension_bounds();
2752            lower
2753                .try_into()
2754                .map_err(|_| EvalError::Int32OutOfRange(lower.to_string().into()))
2755        })
2756        .transpose()
2757}
2758
2759#[sqlfunc(
2760    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2761    sqlname = "array_remove",
2762    propagates_nulls = false,
2763    introduces_nulls = false
2764)]
2765fn array_remove<'a>(
2766    arr: Array<'a>,
2767    b: Datum<'a>,
2768    temp_storage: &'a RowArena,
2769) -> Result<Datum<'a>, EvalError> {
2770    // Zero-dimensional arrays are empty by definition
2771    if arr.dims().len() == 0 {
2772        return Ok(Datum::Array(arr));
2773    }
2774
2775    // array_remove only supports one-dimensional arrays
2776    if arr.dims().len() > 1 {
2777        return Err(EvalError::MultidimensionalArrayRemovalNotSupported);
2778    }
2779
2780    let elems: Vec<_> = arr.elements().iter().filter(|v| v != &b).collect();
2781    let mut dims = arr.dims().into_iter().collect::<Vec<_>>();
2782    // This access is safe because `dims` is guaranteed to be non-empty
2783    dims[0] = ArrayDimension {
2784        lower_bound: 1,
2785        length: elems.len(),
2786    };
2787
2788    Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?)
2789}
2790
2791#[sqlfunc(is_infix_op = true)]
2792// TODO(benesch): remove potentially dangerous usage of `as`.
2793#[allow(clippy::as_conversions)]
2794fn array_upper<'a>(a: Array<'a>, i: i64) -> Result<Option<i32>, EvalError> {
2795    if i < 1 {
2796        return Ok(None);
2797    }
2798    a.dims()
2799        .into_iter()
2800        .nth(i as usize - 1)
2801        .map(|dim| {
2802            let (_lower, upper) = dim.dimension_bounds();
2803            upper
2804                .try_into()
2805                .map_err(|_| EvalError::Int32OutOfRange(upper.to_string().into()))
2806        })
2807        .transpose()
2808}
2809
2810#[sqlfunc(
2811    is_infix_op = true,
2812    sqlname = "array_contains",
2813    propagates_nulls = true,
2814    introduces_nulls = false
2815)]
2816fn array_contains<'a>(a: Datum<'a>, array: Array<'a>) -> bool {
2817    array.elements().iter().any(|e| e == a)
2818}
2819
2820#[sqlfunc(is_infix_op = true, sqlname = "@>")]
2821fn array_contains_array<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2822    let a = a.elements();
2823    let b = b.elements();
2824
2825    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
2826    if b.iter().contains(&Datum::Null) {
2827        false
2828    } else {
2829        b.iter()
2830            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
2831    }
2832}
2833
2834#[sqlfunc(is_infix_op = true, sqlname = "<@")]
2835fn array_contains_array_rev<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2836    array_contains_array(b, a)
2837}
2838
2839#[sqlfunc(
2840    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2841    is_infix_op = true,
2842    sqlname = "||",
2843    propagates_nulls = false,
2844    introduces_nulls = false
2845)]
2846fn array_array_concat<'a>(
2847    a: Option<Array<'a>>,
2848    b: Option<Array<'a>>,
2849    temp_storage: &'a RowArena,
2850) -> Result<Option<Array<'a>>, EvalError> {
2851    let Some(a_array) = a else {
2852        return Ok(b);
2853    };
2854    let Some(b_array) = b else {
2855        return Ok(a);
2856    };
2857
2858    let a_dims: Vec<ArrayDimension> = a_array.dims().into_iter().collect();
2859    let b_dims: Vec<ArrayDimension> = b_array.dims().into_iter().collect();
2860
2861    let a_ndims = a_dims.len();
2862    let b_ndims = b_dims.len();
2863
2864    // Per PostgreSQL, if either of the input arrays is zero dimensional,
2865    // the output is the other array, no matter their dimensions.
2866    if a_ndims == 0 {
2867        return Ok(b);
2868    } else if b_ndims == 0 {
2869        return Ok(a);
2870    }
2871
2872    // Postgres supports concatenating arrays of different dimensions,
2873    // as long as one of the arrays has the same type as an element of
2874    // the other array, i.e. `int[2][4] || int[4]` (or `int[4] || int[2][4]`)
2875    // works, because each element of `int[2][4]` is an `int[4]`.
2876    // This check is separate from the one below because Postgres gives a
2877    // specific error message if the number of dimensions differs by more
2878    // than one.
2879    // This cast is safe since MAX_ARRAY_DIMENSIONS is 6
2880    // Can be replaced by .abs_diff once it is stabilized
2881    // TODO(benesch): remove potentially dangerous usage of `as`.
2882    #[allow(clippy::as_conversions)]
2883    if (a_ndims as isize - b_ndims as isize).abs() > 1 {
2884        return Err(EvalError::IncompatibleArrayDimensions {
2885            dims: Some((a_ndims, b_ndims)),
2886        });
2887    }
2888
2889    let mut dims;
2890
2891    // After the checks above, we are certain that:
2892    // - neither array is zero dimensional nor empty
2893    // - both arrays have the same number of dimensions, or differ
2894    //   at most by one.
2895    match a_ndims.cmp(&b_ndims) {
2896        // If both arrays have the same number of dimensions, validate
2897        // that their inner dimensions are the same and concatenate the
2898        // arrays.
2899        Ordering::Equal => {
2900            if &a_dims[1..] != &b_dims[1..] {
2901                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2902            }
2903            dims = vec![ArrayDimension {
2904                lower_bound: a_dims[0].lower_bound,
2905                length: a_dims[0].length + b_dims[0].length,
2906            }];
2907            dims.extend(&a_dims[1..]);
2908        }
2909        // If `a` has less dimensions than `b`, this is an element-array
2910        // concatenation, which requires that `a` has the same dimensions
2911        // as an element of `b`.
2912        Ordering::Less => {
2913            if &a_dims[..] != &b_dims[1..] {
2914                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2915            }
2916            dims = vec![ArrayDimension {
2917                lower_bound: b_dims[0].lower_bound,
2918                // Since `a` is treated as an element of `b`, the length of
2919                // the first dimension of `b` is incremented by one, as `a` is
2920                // non-empty.
2921                length: b_dims[0].length + 1,
2922            }];
2923            dims.extend(a_dims);
2924        }
2925        // If `a` has more dimensions than `b`, this is an array-element
2926        // concatenation, which requires that `b` has the same dimensions
2927        // as an element of `a`.
2928        Ordering::Greater => {
2929            if &a_dims[1..] != &b_dims[..] {
2930                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2931            }
2932            dims = vec![ArrayDimension {
2933                lower_bound: a_dims[0].lower_bound,
2934                // Since `b` is treated as an element of `a`, the length of
2935                // the first dimension of `a` is incremented by one, as `b`
2936                // is non-empty.
2937                length: a_dims[0].length + 1,
2938            }];
2939            dims.extend(b_dims);
2940        }
2941    }
2942
2943    let elems = a_array.elements().iter().chain(b_array.elements().iter());
2944
2945    let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?;
2946    Ok(Some(datum.unwrap_array()))
2947}
2948
2949#[sqlfunc(
2950    is_infix_op = true,
2951    sqlname = "||",
2952    propagates_nulls = false,
2953    introduces_nulls = false
2954)]
2955fn list_list_concat<'a, T: FromDatum<'a>>(
2956    a: Option<DatumList<'a, T>>,
2957    b: Option<DatumList<'a, T>>,
2958    temp_storage: &'a RowArena,
2959) -> Option<DatumList<'a, T>> {
2960    let Some(a) = a else {
2961        return b;
2962    };
2963    let Some(b) = b else {
2964        return Some(a);
2965    };
2966
2967    Some(temp_storage.make_datum_list(a.typed_iter().chain(b.typed_iter())))
2968}
2969
2970#[sqlfunc(is_infix_op = true, sqlname = "||", propagates_nulls = false)]
2971fn list_element_concat<'a, T: FromDatum<'a>>(
2972    a: Option<DatumList<'a, T>>,
2973    b: T,
2974    temp_storage: &'a RowArena,
2975) -> DatumList<'a, T> {
2976    let a_elems = a.into_iter().flat_map(|a| a.typed_iter());
2977    temp_storage.make_datum_list(a_elems.chain(std::iter::once(b)))
2978}
2979
2980// Note that the output type corresponds to the _second_ parameter's input type.
2981#[sqlfunc(is_infix_op = true, sqlname = "||", propagates_nulls = false)]
2982fn element_list_concat<'a, T: FromDatum<'a>>(
2983    a: T,
2984    b: Option<DatumList<'a, T>>,
2985    temp_storage: &'a RowArena,
2986) -> DatumList<'a, T> {
2987    let b_elems = b.into_iter().flat_map(|b| b.typed_iter());
2988    temp_storage.make_datum_list(std::iter::once(a).chain(b_elems))
2989}
2990
2991#[sqlfunc(sqlname = "list_remove")]
2992fn list_remove<'a, T: FromDatum<'a>>(
2993    a: DatumList<'a, T>,
2994    b: T,
2995    temp_storage: &'a RowArena,
2996) -> DatumList<'a, T> {
2997    temp_storage.make_datum_list(a.typed_iter().filter(|elem| *elem != b))
2998}
2999
3000#[sqlfunc(sqlname = "digest")]
3001fn digest_string(to_digest: &str, digest_fn: &str) -> Result<Vec<u8>, EvalError> {
3002    digest_inner(to_digest.as_bytes(), digest_fn)
3003}
3004
3005#[sqlfunc(sqlname = "digest")]
3006fn digest_bytes(to_digest: &[u8], digest_fn: &str) -> Result<Vec<u8>, EvalError> {
3007    digest_inner(to_digest, digest_fn)
3008}
3009
3010fn digest_inner(bytes: &[u8], digest_fn: &str) -> Result<Vec<u8>, EvalError> {
3011    match digest_fn {
3012        "md5" => Ok(Md5::digest(bytes).to_vec()),
3013        "sha1" => Ok(digest::digest(&digest::SHA1_FOR_LEGACY_USE_ONLY, bytes)
3014            .as_ref()
3015            .to_vec()),
3016        "sha224" => Ok(digest::digest(&digest::SHA224, bytes).as_ref().to_vec()),
3017        "sha256" => Ok(digest::digest(&digest::SHA256, bytes).as_ref().to_vec()),
3018        "sha384" => Ok(digest::digest(&digest::SHA384, bytes).as_ref().to_vec()),
3019        "sha512" => Ok(digest::digest(&digest::SHA512, bytes).as_ref().to_vec()),
3020        other => Err(EvalError::InvalidHashAlgorithm(other.into())),
3021    }
3022}
3023
3024#[sqlfunc]
3025fn mz_render_typmod(oid: u32, typmod: i32) -> String {
3026    match Type::from_oid_and_typmod(oid, typmod) {
3027        Ok(typ) => typ.constraint().display_or("").to_string(),
3028        // Match dubious PostgreSQL behavior of outputting the unmodified
3029        // `typmod` when positive if the type OID/typmod is invalid.
3030        Err(_) if typmod >= 0 => format!("({typmod})"),
3031        Err(_) => "".into(),
3032    }
3033}
3034
#[cfg(test)]
mod test {
    use chrono::prelude::*;
    use mz_repr::PropDatum;
    use proptest::prelude::*;

    use super::*;
    use crate::{Eval, MirScalarExpr};

    #[mz_ore::test]
    fn add_interval_months() {
        let start = ym(2000, 1);

        // Each case is (months to add, (expected year, expected month)).
        let from_january_2000 = [
            (0, (2000, 1)),
            (1, (2000, 2)),
            (12, (2001, 1)),
            (13, (2001, 2)),
            (24, (2002, 1)),
            (30, (2002, 7)),
            // and negatives
            (-1, (1999, 12)),
            (-12, (1999, 1)),
            (-13, (1998, 12)),
            (-24, (1998, 1)),
            (-30, (1997, 7)),
        ];
        for (months, (year, month)) in from_january_2000 {
            assert_eq!(
                add_timestamp_months(&*start, months).unwrap(),
                ym(year, month)
            );
        }

        // and going over a year boundary by less than a year
        let december = ym(1999, 12);
        assert_eq!(add_timestamp_months(&*december, 1).unwrap(), ym(2000, 1));

        // Starting on the last day of a month must clamp to the last day of
        // the target month.
        let new_years_eve = ymd_at_nine(1999, 12, 31);

        // leap year
        let leap_february: CheckedTimestamp<NaiveDateTime> =
            ymd_at_nine(2000, 2, 29).try_into().unwrap();
        assert_eq!(
            add_timestamp_months(&new_years_eve, 2).unwrap(),
            leap_february
        );

        // not leap year
        let plain_february: CheckedTimestamp<NaiveDateTime> =
            ymd_at_nine(2001, 2, 28).try_into().unwrap();
        assert_eq!(
            add_timestamp_months(&new_years_eve, 14).unwrap(),
            plain_february
        );
    }

    /// Builds the timestamp for 09:09:09 on the given calendar day.
    fn ymd_at_nine(year: i32, month: u32, day: u32) -> NaiveDateTime {
        NaiveDate::from_ymd_opt(year, month, day)
            .unwrap()
            .and_hms_opt(9, 9, 9)
            .unwrap()
    }

    /// Builds the checked timestamp for 09:09:09 on the first day of the
    /// given month.
    fn ym(year: i32, month: u32) -> CheckedTimestamp<NaiveDateTime> {
        ymd_at_nine(year, month, 1).try_into().unwrap()
    }

    #[mz_ore::test]
    fn array_lower_upper_respect_lower_bound() {
        use mz_repr::adt::array::{Array, ArrayDimension};
        use mz_repr::{Datum, RowArena};

        /// Packs a one-dimensional array of `length` zeroes with the given
        /// lower bound into `arena`.
        fn one_dim_array<'a>(arena: &'a RowArena, lower_bound: isize, length: usize) -> Array<'a> {
            let dims = [ArrayDimension {
                lower_bound,
                length,
            }];
            let elems = vec![Datum::Int32(0); length];
            match arena.make_datum(|packer| packer.try_push_array(&dims, elems).unwrap()) {
                Datum::Array(arr) => arr,
                other => panic!("expected array, got {other:?}"),
            }
        }

        let arena = RowArena::new();

        // Returns (array_lower(_, 1), array_upper(_, 1)) for a one-dimensional
        // array with the given lower bound and length.
        let bounds = |lower_bound: isize, length: usize| {
            let arr = one_dim_array(&arena, lower_bound, length);
            (array_lower(arr, 1).unwrap(), array_upper(arr, 1).unwrap())
        };

        // Default lower bound of 1: array_fill(0, ARRAY[3]).
        assert_eq!(bounds(1, 3), (Some(1), Some(3)));
        // Lower bound of 5: array_fill(0, ARRAY[3], ARRAY[5]) => [5:7].
        assert_eq!(bounds(5, 3), (Some(5), Some(7)));
        // Negative lower bound: array_fill(0, ARRAY[3], ARRAY[-3]) => [-3:-1].
        assert_eq!(bounds(-3, 3), (Some(-3), Some(-1)));

        // Out-of-range dimensions return None rather than the bound.
        let arr = one_dim_array(&arena, 5, 3);
        for dimension in [0, 2] {
            assert_eq!(array_lower(arr, dimension).unwrap(), None);
            assert_eq!(array_upper(arr, dimension).unwrap(), None);
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that evaluating `expr` over the given argument sets, in
        /// order, yields results that are either never decreasing or never
        /// increasing. Does nothing if any evaluation fails.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let evaluated: Result<Vec<_>, _> = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect();
            let Ok(results) = evaluated else {
                return;
            };

            let non_decreasing = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let non_increasing = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                non_decreasing || non_increasing,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Checks the monotonicity that `func` declares for each of its two
        /// arguments against randomly generated, sorted inputs.
        fn proptest_binary<'a>(
            func: BinaryFunc,
            arena: &'a RowArena,
            left: impl Strategy<Value = PropDatum>,
            right: impl Strategy<Value = PropDatum>,
        ) {
            let (left_monotone, right_monotone) = func.is_monotone();
            let expr = MirScalarExpr::CallBinary {
                func,
                expr1: Box::new(MirScalarExpr::column(0)),
                expr2: Box::new(MirScalarExpr::column(1)),
            };
            proptest!(|(
                mut left in proptest::array::uniform3(left),
                mut right in proptest::array::uniform3(right),
            )| {
                left.sort();
                right.sort();
                if left_monotone {
                    // Hold the right argument fixed and sweep the sorted left values.
                    for fixed_right in &right {
                        let args: Vec<[_; 2]> = left
                            .iter()
                            .map(|swept| [Datum::from(swept), Datum::from(fixed_right)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
                if right_monotone {
                    // Hold the left argument fixed and sweep the sorted right values.
                    for fixed_left in &left {
                        let args: Vec<[_; 2]> = right
                            .iter()
                            .map(|swept| [Datum::from(fixed_left), Datum::from(swept)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
            });
        }

        // Strings: short random uppercase strings plus the type's curated
        // interesting values.
        let interesting_strs: Vec<_> = SqlScalarType::String.interesting_datums().collect();
        let random_strs = proptest::string::string_regex("[A-Z]{0,10}")
            .expect("valid regex")
            .prop_map(|s| PropDatum::String(s.to_string()))
            .boxed();
        let curated_strs = (0..interesting_strs.len())
            .prop_map(move |i| match interesting_strs[i] {
                Datum::String(val) => PropDatum::String(val.to_string()),
                _ => unreachable!("interesting strings has non-strings"),
            })
            .boxed();
        let str_datums = proptest::strategy::Union::new([random_strs, curated_strs]);

        // Integers: arbitrary values, the type's curated interesting values,
        // and a dense range around zero.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let arbitrary_i32s = any::<i32>().prop_map(PropDatum::Int32).boxed();
        let curated_i32s = (0..interesting_i32s.len())
            .prop_map(move |i| match interesting_i32s[i] {
                Datum::Int32(val) => PropDatum::Int32(val),
                _ => unreachable!("interesting int32 has non-i32s"),
            })
            .boxed();
        let small_i32s = (-10i32..10).prop_map(PropDatum::Int32).boxed();
        let i32_datums =
            proptest::strategy::Union::new([arbitrary_i32s, curated_i32s, small_i32s]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_binary(
            BinaryFunc::AddInt32(AddInt32),
            &arena,
            &i32_datums,
            &i32_datums,
        );
        proptest_binary(SubInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(MulInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(DivInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(TextConcatBinary.into(), &arena, &str_datums, &str_datums);
        proptest_binary(Left.into(), &arena, &str_datums, &i32_datums);
    }
}