Skip to main content

mz_expr/scalar/
func.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14use std::borrow::Cow;
15use std::cmp::Ordering;
16use std::convert::{TryFrom, TryInto};
17use std::str::FromStr;
18use std::{iter, str};
19
20use ::encoding::DecoderTrap;
21use ::encoding::label::encoding_from_whatwg_label;
22use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, TimeZone, Timelike, Utc};
23use chrono_tz::{OffsetComponents, OffsetName, Tz};
24use dec::OrderedDecimal;
25use itertools::Itertools;
26use md5::{Digest, Md5};
27use mz_expr_derive::sqlfunc;
28use mz_ore::cast::{self, CastFrom};
29use mz_ore::fmt::FormatBuffer;
30use mz_ore::lex::LexBuf;
31use mz_ore::option::OptionExt;
32use mz_pgrepr::Type;
33use mz_pgtz::timezone::{Timezone, TimezoneSpec};
34use mz_repr::adt::array::{Array, ArrayDimension};
35use mz_repr::adt::date::Date;
36use mz_repr::adt::interval::{Interval, RoundBehavior};
37use mz_repr::adt::jsonb::JsonbRef;
38use mz_repr::adt::mz_acl_item::{AclMode, MzAclItem};
39use mz_repr::adt::numeric::{self, Numeric};
40use mz_repr::adt::range::Range;
41use mz_repr::adt::regex::Regex;
42use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
43use mz_repr::{
44    ArrayRustType, Datum, DatumList, DatumMap, ExcludeNull, InputDatumType, OutputDatumType, Row,
45    RowArena, SqlScalarType, strconv,
46};
47use mz_sql_parser::ast::display::FormatMode;
48use mz_sql_pretty::{PrettyConfig, pretty_str};
49use num::traits::CheckedNeg;
50use sha1::Sha1;
51use sha2::{Sha224, Sha256, Sha384, Sha512};
52use subtle::ConstantTimeEq;
53
54use crate::scalar::func::format::DateTimeFormat;
55use crate::{EvalError, like_pattern};
56
57#[macro_use]
58mod macros;
59mod binary;
60mod encoding;
61pub(crate) mod format;
62pub(crate) mod impls;
63mod unary;
64mod unmaterializable;
65mod variadic;
66
67pub use binary::BinaryFunc;
68pub use impls::*;
69pub use unary::{EagerUnaryFunc, LazyUnaryFunc, UnaryFunc};
70pub use unmaterializable::UnmaterializableFunc;
71pub use variadic::VariadicFunc;
72
73/// The maximum size of the result strings of certain string functions, such as `repeat` and `lpad`.
74/// Chosen to be the smallest number to keep our tests passing without changing. 100MiB is probably
75/// higher than what we want, but it's better than no limit.
76///
77/// Note: This number appears in our user-facing documentation in the function reference for every
78/// function where it applies.
79pub const MAX_STRING_FUNC_RESULT_BYTES: usize = 1024 * 1024 * 100;
80
81pub fn jsonb_stringify<'a>(a: Datum<'a>, temp_storage: &'a RowArena) -> Option<&'a str> {
82    match a {
83        Datum::JsonNull => None,
84        Datum::String(s) => Some(s),
85        _ => {
86            let s = cast_jsonb_to_string(JsonbRef::from_datum(a));
87            Some(temp_storage.push_string(s))
88        }
89    }
90}
91
92#[sqlfunc(
93    is_monotone = "(true, true)",
94    is_infix_op = true,
95    sqlname = "+",
96    propagates_nulls = true
97)]
98fn add_int16(a: i16, b: i16) -> Result<i16, EvalError> {
99    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
100}
101
102#[sqlfunc(
103    is_monotone = "(true, true)",
104    is_infix_op = true,
105    sqlname = "+",
106    propagates_nulls = true
107)]
108fn add_int32(a: i32, b: i32) -> Result<i32, EvalError> {
109    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
110}
111
112#[sqlfunc(
113    is_monotone = "(true, true)",
114    is_infix_op = true,
115    sqlname = "+",
116    propagates_nulls = true
117)]
118fn add_int64(a: i64, b: i64) -> Result<i64, EvalError> {
119    a.checked_add(b).ok_or(EvalError::NumericFieldOverflow)
120}
121
122#[sqlfunc(
123    is_monotone = "(true, true)",
124    is_infix_op = true,
125    sqlname = "+",
126    propagates_nulls = true
127)]
128fn add_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
129    a.checked_add(b)
130        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} + {b}").into()))
131}
132
133#[sqlfunc(
134    is_monotone = "(true, true)",
135    is_infix_op = true,
136    sqlname = "+",
137    propagates_nulls = true
138)]
139fn add_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
140    a.checked_add(b)
141        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} + {b}").into()))
142}
143
144#[sqlfunc(
145    is_monotone = "(true, true)",
146    is_infix_op = true,
147    sqlname = "+",
148    propagates_nulls = true
149)]
150fn add_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
151    a.checked_add(b)
152        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} + {b}").into()))
153}
154
155#[sqlfunc(
156    is_monotone = "(true, true)",
157    is_infix_op = true,
158    sqlname = "+",
159    propagates_nulls = true
160)]
161fn add_float32(a: f32, b: f32) -> Result<f32, EvalError> {
162    let sum = a + b;
163    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
164        Err(EvalError::FloatOverflow)
165    } else {
166        Ok(sum)
167    }
168}
169
170#[sqlfunc(
171    is_monotone = "(true, true)",
172    is_infix_op = true,
173    sqlname = "+",
174    propagates_nulls = true
175)]
176fn add_float64(a: f64, b: f64) -> Result<f64, EvalError> {
177    let sum = a + b;
178    if sum.is_infinite() && !a.is_infinite() && !b.is_infinite() {
179        Err(EvalError::FloatOverflow)
180    } else {
181        Ok(sum)
182    }
183}
184
185#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
186fn add_timestamp_interval(
187    a: CheckedTimestamp<NaiveDateTime>,
188    b: Interval,
189) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
190    add_timestamplike_interval(a, b)
191}
192
193#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
194fn add_timestamp_tz_interval(
195    a: CheckedTimestamp<DateTime<Utc>>,
196    b: Interval,
197) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
198    add_timestamplike_interval(a, b)
199}
200
201fn add_timestamplike_interval<T>(
202    a: CheckedTimestamp<T>,
203    b: Interval,
204) -> Result<CheckedTimestamp<T>, EvalError>
205where
206    T: TimestampLike,
207{
208    let dt = a.date_time();
209    let dt = add_timestamp_months(&dt, b.months)?;
210    let dt = dt
211        .checked_add_signed(b.duration_as_chrono())
212        .ok_or(EvalError::TimestampOutOfRange)?;
213    Ok(CheckedTimestamp::from_timestamplike(T::from_date_time(dt))?)
214}
215
216#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
217fn sub_timestamp_interval(
218    a: CheckedTimestamp<NaiveDateTime>,
219    b: Interval,
220) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
221    sub_timestamplike_interval(a, b)
222}
223
224#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
225fn sub_timestamp_tz_interval(
226    a: CheckedTimestamp<DateTime<Utc>>,
227    b: Interval,
228) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
229    sub_timestamplike_interval(a, b)
230}
231
232fn sub_timestamplike_interval<T>(
233    a: CheckedTimestamp<T>,
234    b: Interval,
235) -> Result<CheckedTimestamp<T>, EvalError>
236where
237    T: TimestampLike,
238{
239    neg_interval_inner(b).and_then(|i| add_timestamplike_interval(a, i))
240}
241
242#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
243fn add_date_time(
244    date: Date,
245    time: chrono::NaiveTime,
246) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
247    let dt = NaiveDate::from(date)
248        .and_hms_nano_opt(time.hour(), time.minute(), time.second(), time.nanosecond())
249        .unwrap();
250    Ok(CheckedTimestamp::from_timestamplike(dt)?)
251}
252
253#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "+")]
254fn add_date_interval(
255    date: Date,
256    interval: Interval,
257) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
258    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
259    let dt = add_timestamp_months(&dt, interval.months)?;
260    let dt = dt
261        .checked_add_signed(interval.duration_as_chrono())
262        .ok_or(EvalError::TimestampOutOfRange)?;
263    Ok(CheckedTimestamp::from_timestamplike(dt)?)
264}
265
266#[sqlfunc(
267    // <time> + <interval> wraps!
268    is_monotone = "(false, false)",
269    is_infix_op = true,
270    sqlname = "+",
271    propagates_nulls = true
272)]
273fn add_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
274    let (t, _) = time.overflowing_add_signed(interval.duration_as_chrono());
275    t
276}
277
278#[sqlfunc(
279    is_monotone = "(true, false)",
280    output_type = "Numeric",
281    sqlname = "round",
282    propagates_nulls = true
283)]
284fn round_numeric_binary(a: OrderedDecimal<Numeric>, mut b: i32) -> Result<Numeric, EvalError> {
285    let mut a = a.0;
286    let mut cx = numeric::cx_datum();
287    let a_exp = a.exponent();
288    if a_exp > 0 && b > 0 || a_exp < 0 && -a_exp < b {
289        // This condition indicates:
290        // - a is a value without a decimal point, b is a positive number
291        // - a has a decimal point, but b is larger than its scale
292        // In both of these situations, right-pad the number with zeroes, which // is most easily done with rescale.
293
294        // Ensure rescale doesn't exceed max precision by putting a ceiling on
295        // b equal to the maximum remaining scale the value can support.
296        let max_remaining_scale = u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
297            - (numeric::get_precision(&a) - numeric::get_scale(&a));
298        b = match i32::try_from(max_remaining_scale) {
299            Ok(max_remaining_scale) => std::cmp::min(b, max_remaining_scale),
300            Err(_) => b,
301        };
302        cx.rescale(&mut a, &numeric::Numeric::from(-b));
303    } else {
304        // To avoid invalid operations, clamp b to be within 1 more than the
305        // precision limit.
306        const MAX_P_LIMIT: i32 = 1 + cast::u8_to_i32(numeric::NUMERIC_DATUM_MAX_PRECISION);
307        b = std::cmp::min(MAX_P_LIMIT, b);
308        b = std::cmp::max(-MAX_P_LIMIT, b);
309        let mut b = numeric::Numeric::from(b);
310        // Shift by 10^b; this put digit to round to in the one's place.
311        cx.scaleb(&mut a, &b);
312        cx.round(&mut a);
313        // Negate exponent for shift back
314        cx.neg(&mut b);
315        cx.scaleb(&mut a, &b);
316    }
317
318    if cx.status().overflow() {
319        Err(EvalError::FloatOverflow)
320    } else if a.is_zero() {
321        // simpler than handling cases where exponent has gotten set to some
322        // value greater than the max precision, but all significant digits
323        // were rounded away.
324        Ok(numeric::Numeric::zero())
325    } else {
326        numeric::munge_numeric(&mut a).unwrap();
327        Ok(a)
328    }
329}
330
331#[sqlfunc(sqlname = "convert_from", propagates_nulls = true)]
332fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
333    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
334    // which the encoding library uses[3].
335    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
336    // [2]: https://encoding.spec.whatwg.org/
337    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
338    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
339
340    // Supporting other encodings is tracked by database-issues#797.
341    if encoding_from_whatwg_label(&encoding_name).map(|e| e.name()) != Some("utf-8") {
342        return Err(EvalError::InvalidEncodingName(encoding_name));
343    }
344
345    match str::from_utf8(a) {
346        Ok(from) => Ok(from),
347        Err(e) => Err(EvalError::InvalidByteSequence {
348            byte_sequence: e.to_string().into(),
349            encoding_name,
350        }),
351    }
352}
353
354#[sqlfunc]
355fn encode(bytes: &[u8], format: &str) -> Result<String, EvalError> {
356    let format = encoding::lookup_format(format)?;
357    Ok(format.encode(bytes))
358}
359
360#[sqlfunc]
361fn decode(string: &str, format: &str) -> Result<Vec<u8>, EvalError> {
362    let format = encoding::lookup_format(format)?;
363    let out = format.decode(string)?;
364    if out.len() > MAX_STRING_FUNC_RESULT_BYTES {
365        Err(EvalError::LengthTooLarge)
366    } else {
367        Ok(out)
368    }
369}
370
371#[sqlfunc(sqlname = "length", propagates_nulls = true)]
372fn encoded_bytes_char_length(a: &[u8], b: &str) -> Result<i32, EvalError> {
373    // Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
374    // which the encoding library uses[3].
375    // [1]: https://www.postgresql.org/docs/9.5/multibyte.html
376    // [2]: https://encoding.spec.whatwg.org/
377    // [3]: https://github.com/lifthrasiir/rust-encoding/blob/4e79c35ab6a351881a86dbff565c4db0085cc113/src/label.rs
378    let encoding_name = b.to_lowercase().replace('_', "-").into_boxed_str();
379
380    let enc = match encoding_from_whatwg_label(&encoding_name) {
381        Some(enc) => enc,
382        None => return Err(EvalError::InvalidEncodingName(encoding_name)),
383    };
384
385    let decoded_string = match enc.decode(a, DecoderTrap::Strict) {
386        Ok(s) => s,
387        Err(e) => {
388            return Err(EvalError::InvalidByteSequence {
389                byte_sequence: e.into(),
390                encoding_name,
391            });
392        }
393    };
394
395    let count = decoded_string.chars().count();
396    i32::try_from(count).map_err(|_| EvalError::Int32OutOfRange(count.to_string().into()))
397}
398
399// TODO(benesch): remove potentially dangerous usage of `as`.
400#[allow(clippy::as_conversions)]
401pub fn add_timestamp_months<T: TimestampLike>(
402    dt: &T,
403    mut months: i32,
404) -> Result<CheckedTimestamp<T>, EvalError> {
405    if months == 0 {
406        return Ok(CheckedTimestamp::from_timestamplike(dt.clone())?);
407    }
408
409    let (mut year, mut month, mut day) = (dt.year(), dt.month0() as i32, dt.day());
410    let years = months / 12;
411    year = year
412        .checked_add(years)
413        .ok_or(EvalError::TimestampOutOfRange)?;
414
415    months %= 12;
416    // positive modulus is easier to reason about
417    if months < 0 {
418        year -= 1;
419        months += 12;
420    }
421    year += (month + months) / 12;
422    month = (month + months) % 12;
423    // account for dt.month0
424    month += 1;
425
426    // handle going from January 31st to February by saturation
427    let mut new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
428    while new_d.is_none() {
429        // If we have decremented day past 28 and are still receiving `None`,
430        // then we have generally overflowed `NaiveDate`.
431        if day < 28 {
432            return Err(EvalError::TimestampOutOfRange);
433        }
434        day -= 1;
435        new_d = chrono::NaiveDate::from_ymd_opt(year, month as u32, day);
436    }
437    let new_d = new_d.unwrap();
438
439    // Neither postgres nor mysql support leap seconds, so this should be safe.
440    //
441    // Both my testing and https://dba.stackexchange.com/a/105829 support the
442    // idea that we should ignore leap seconds
443    let new_dt = new_d
444        .and_hms_nano_opt(dt.hour(), dt.minute(), dt.second(), dt.nanosecond())
445        .unwrap();
446    let new_dt = T::from_date_time(new_dt);
447    Ok(CheckedTimestamp::from_timestamplike(new_dt)?)
448}
449
450#[sqlfunc(
451    is_monotone = "(true, true)",
452    is_infix_op = true,
453    sqlname = "+",
454    propagates_nulls = true
455)]
456fn add_numeric(
457    a: OrderedDecimal<Numeric>,
458    b: OrderedDecimal<Numeric>,
459) -> Result<Numeric, EvalError> {
460    let mut cx = numeric::cx_datum();
461    let mut a = a.0;
462    cx.add(&mut a, &b.0);
463    if cx.status().overflow() {
464        Err(EvalError::FloatOverflow)
465    } else {
466        Ok(a)
467    }
468}
469
470#[sqlfunc(
471    is_monotone = "(true, true)",
472    is_infix_op = true,
473    sqlname = "+",
474    propagates_nulls = true
475)]
476fn add_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
477    a.checked_add(&b)
478        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} + {b}").into()))
479}
480
481#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
482fn bit_and_int16(a: i16, b: i16) -> i16 {
483    a & b
484}
485
486#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
487fn bit_and_int32(a: i32, b: i32) -> i32 {
488    a & b
489}
490
491#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
492fn bit_and_int64(a: i64, b: i64) -> i64 {
493    a & b
494}
495
496#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
497fn bit_and_uint16(a: u16, b: u16) -> u16 {
498    a & b
499}
500
501#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
502fn bit_and_uint32(a: u32, b: u32) -> u32 {
503    a & b
504}
505
506#[sqlfunc(is_infix_op = true, sqlname = "&", propagates_nulls = true)]
507fn bit_and_uint64(a: u64, b: u64) -> u64 {
508    a & b
509}
510
511#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
512fn bit_or_int16(a: i16, b: i16) -> i16 {
513    a | b
514}
515
516#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
517fn bit_or_int32(a: i32, b: i32) -> i32 {
518    a | b
519}
520
521#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
522fn bit_or_int64(a: i64, b: i64) -> i64 {
523    a | b
524}
525
526#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
527fn bit_or_uint16(a: u16, b: u16) -> u16 {
528    a | b
529}
530
531#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
532fn bit_or_uint32(a: u32, b: u32) -> u32 {
533    a | b
534}
535
536#[sqlfunc(is_infix_op = true, sqlname = "|", propagates_nulls = true)]
537fn bit_or_uint64(a: u64, b: u64) -> u64 {
538    a | b
539}
540
541#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
542fn bit_xor_int16(a: i16, b: i16) -> i16 {
543    a ^ b
544}
545
546#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
547fn bit_xor_int32(a: i32, b: i32) -> i32 {
548    a ^ b
549}
550
551#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
552fn bit_xor_int64(a: i64, b: i64) -> i64 {
553    a ^ b
554}
555
556#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
557fn bit_xor_uint16(a: u16, b: u16) -> u16 {
558    a ^ b
559}
560
561#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
562fn bit_xor_uint32(a: u32, b: u32) -> u32 {
563    a ^ b
564}
565
566#[sqlfunc(is_infix_op = true, sqlname = "#", propagates_nulls = true)]
567fn bit_xor_uint64(a: u64, b: u64) -> u64 {
568    a ^ b
569}
570
571#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
572// TODO(benesch): remove potentially dangerous usage of `as`.
573#[allow(clippy::as_conversions)]
574fn bit_shift_left_int16(a: i16, b: i32) -> i16 {
575    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
576    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
577    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
578    let lhs: i32 = a as i32;
579    let rhs: u32 = b as u32;
580    lhs.wrapping_shl(rhs) as i16
581}
582
583#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
584// TODO(benesch): remove potentially dangerous usage of `as`.
585#[allow(clippy::as_conversions)]
586fn bit_shift_left_int32(lhs: i32, rhs: i32) -> i32 {
587    let rhs = rhs as u32;
588    lhs.wrapping_shl(rhs)
589}
590
591#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
592// TODO(benesch): remove potentially dangerous usage of `as`.
593#[allow(clippy::as_conversions)]
594fn bit_shift_left_int64(lhs: i64, rhs: i32) -> i64 {
595    let rhs = rhs as u32;
596    lhs.wrapping_shl(rhs)
597}
598
599#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
600// TODO(benesch): remove potentially dangerous usage of `as`.
601#[allow(clippy::as_conversions)]
602fn bit_shift_left_uint16(a: u16, b: u32) -> u16 {
603    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
604    // when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
605    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
606    let lhs: u32 = a as u32;
607    let rhs: u32 = b;
608    lhs.wrapping_shl(rhs) as u16
609}
610
611#[sqlfunc(is_infix_op = true, sqlname = "<<", propagates_nulls = true)]
612fn bit_shift_left_uint32(a: u32, b: u32) -> u32 {
613    let lhs = a;
614    let rhs = b;
615    lhs.wrapping_shl(rhs)
616}
617
618#[sqlfunc(
619    output_type = "u64",
620    is_infix_op = true,
621    sqlname = "<<",
622    propagates_nulls = true
623)]
624fn bit_shift_left_uint64(lhs: u64, rhs: u32) -> u64 {
625    lhs.wrapping_shl(rhs)
626}
627
628#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
629// TODO(benesch): remove potentially dangerous usage of `as`.
630#[allow(clippy::as_conversions)]
631fn bit_shift_right_int16(lhs: i16, rhs: i32) -> i16 {
632    // widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
633    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
634    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
635    let lhs = lhs as i32;
636    let rhs = rhs as u32;
637    lhs.wrapping_shr(rhs) as i16
638}
639
640#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
641// TODO(benesch): remove potentially dangerous usage of `as`.
642#[allow(clippy::as_conversions)]
643fn bit_shift_right_int32(lhs: i32, rhs: i32) -> i32 {
644    lhs.wrapping_shr(rhs as u32)
645}
646
647#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
648// TODO(benesch): remove potentially dangerous usage of `as`.
649#[allow(clippy::as_conversions)]
650fn bit_shift_right_int64(lhs: i64, rhs: i32) -> i64 {
651    lhs.wrapping_shr(rhs as u32)
652}
653
654#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
655// TODO(benesch): remove potentially dangerous usage of `as`.
656#[allow(clippy::as_conversions)]
657fn bit_shift_right_uint16(lhs: u16, rhs: u32) -> u16 {
658    // widen to u32 and then cast back to u16 in order emulate the C promotion rules used in by Postgres
659    // when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
660    // see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
661    let lhs = lhs as u32;
662    lhs.wrapping_shr(rhs) as u16
663}
664
665#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
666fn bit_shift_right_uint32(lhs: u32, rhs: u32) -> u32 {
667    lhs.wrapping_shr(rhs)
668}
669
670#[sqlfunc(is_infix_op = true, sqlname = ">>", propagates_nulls = true)]
671fn bit_shift_right_uint64(lhs: u64, rhs: u32) -> u64 {
672    lhs.wrapping_shr(rhs)
673}
674
675#[sqlfunc(
676    is_monotone = "(true, true)",
677    is_infix_op = true,
678    sqlname = "-",
679    propagates_nulls = true
680)]
681fn sub_int16(a: i16, b: i16) -> Result<i16, EvalError> {
682    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
683}
684
685#[sqlfunc(
686    is_monotone = "(true, true)",
687    is_infix_op = true,
688    sqlname = "-",
689    propagates_nulls = true
690)]
691fn sub_int32(a: i32, b: i32) -> Result<i32, EvalError> {
692    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
693}
694
695#[sqlfunc(
696    is_monotone = "(true, true)",
697    is_infix_op = true,
698    sqlname = "-",
699    propagates_nulls = true
700)]
701fn sub_int64(a: i64, b: i64) -> Result<i64, EvalError> {
702    a.checked_sub(b).ok_or(EvalError::NumericFieldOverflow)
703}
704
705#[sqlfunc(
706    is_monotone = "(true, true)",
707    is_infix_op = true,
708    sqlname = "-",
709    propagates_nulls = true
710)]
711fn sub_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
712    a.checked_sub(b)
713        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} - {b}").into()))
714}
715
716#[sqlfunc(
717    is_monotone = "(true, true)",
718    is_infix_op = true,
719    sqlname = "-",
720    propagates_nulls = true
721)]
722fn sub_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
723    a.checked_sub(b)
724        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} - {b}").into()))
725}
726
727#[sqlfunc(
728    is_monotone = "(true, true)",
729    is_infix_op = true,
730    sqlname = "-",
731    propagates_nulls = true
732)]
733fn sub_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
734    a.checked_sub(b)
735        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} - {b}").into()))
736}
737
738#[sqlfunc(
739    is_monotone = "(true, true)",
740    is_infix_op = true,
741    sqlname = "-",
742    propagates_nulls = true
743)]
744fn sub_float32(a: f32, b: f32) -> Result<f32, EvalError> {
745    let difference = a - b;
746    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
747        Err(EvalError::FloatOverflow)
748    } else {
749        Ok(difference)
750    }
751}
752
753#[sqlfunc(
754    is_monotone = "(true, true)",
755    is_infix_op = true,
756    sqlname = "-",
757    propagates_nulls = true
758)]
759fn sub_float64(a: f64, b: f64) -> Result<f64, EvalError> {
760    let difference = a - b;
761    if difference.is_infinite() && !a.is_infinite() && !b.is_infinite() {
762        Err(EvalError::FloatOverflow)
763    } else {
764        Ok(difference)
765    }
766}
767
768#[sqlfunc(
769    is_monotone = "(true, true)",
770    is_infix_op = true,
771    sqlname = "-",
772    propagates_nulls = true
773)]
774fn sub_numeric(
775    a: OrderedDecimal<Numeric>,
776    b: OrderedDecimal<Numeric>,
777) -> Result<Numeric, EvalError> {
778    let mut cx = numeric::cx_datum();
779    let mut a = a.0;
780    cx.sub(&mut a, &b.0);
781    if cx.status().overflow() {
782        Err(EvalError::FloatOverflow)
783    } else {
784        Ok(a)
785    }
786}
787
788#[sqlfunc(
789    is_monotone = "(true, true)",
790    output_type = "Interval",
791    sqlname = "age",
792    propagates_nulls = true
793)]
794fn age_timestamp(
795    a: CheckedTimestamp<chrono::NaiveDateTime>,
796    b: CheckedTimestamp<chrono::NaiveDateTime>,
797) -> Result<Interval, EvalError> {
798    Ok(a.age(&b)?)
799}
800
801#[sqlfunc(is_monotone = "(true, true)", sqlname = "age", propagates_nulls = true)]
802fn age_timestamp_tz(
803    a: CheckedTimestamp<chrono::DateTime<Utc>>,
804    b: CheckedTimestamp<chrono::DateTime<Utc>>,
805) -> Result<Interval, EvalError> {
806    Ok(a.age(&b)?)
807}
808
809#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
810fn sub_timestamp(
811    a: CheckedTimestamp<NaiveDateTime>,
812    b: CheckedTimestamp<NaiveDateTime>,
813) -> Result<Interval, EvalError> {
814    Interval::from_chrono_duration(a - b)
815        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
816}
817
818#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
819fn sub_timestamp_tz(
820    a: CheckedTimestamp<chrono::DateTime<Utc>>,
821    b: CheckedTimestamp<chrono::DateTime<Utc>>,
822) -> Result<Interval, EvalError> {
823    Interval::from_chrono_duration(a - b)
824        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
825}
826
827#[sqlfunc(
828    is_monotone = "(true, true)",
829    is_infix_op = true,
830    sqlname = "-",
831    propagates_nulls = true
832)]
833fn sub_date(a: Date, b: Date) -> i32 {
834    a - b
835}
836
837#[sqlfunc(is_monotone = "(true, true)", is_infix_op = true, sqlname = "-")]
838fn sub_time(a: chrono::NaiveTime, b: chrono::NaiveTime) -> Result<Interval, EvalError> {
839    Interval::from_chrono_duration(a - b)
840        .map_err(|e| EvalError::IntervalOutOfRange(e.to_string().into()))
841}
842
843#[sqlfunc(
844    is_monotone = "(true, true)",
845    output_type = "Interval",
846    is_infix_op = true,
847    sqlname = "-",
848    propagates_nulls = true
849)]
850fn sub_interval(a: Interval, b: Interval) -> Result<Interval, EvalError> {
851    b.checked_neg()
852        .and_then(|b| b.checked_add(&a))
853        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} - {b}").into()))
854}
855
856#[sqlfunc(
857    is_monotone = "(true, true)",
858    is_infix_op = true,
859    sqlname = "-",
860    propagates_nulls = true
861)]
862fn sub_date_interval(
863    date: Date,
864    interval: Interval,
865) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
866    let dt = NaiveDate::from(date).and_hms_opt(0, 0, 0).unwrap();
867    let dt = interval
868        .months
869        .checked_neg()
870        .ok_or_else(|| EvalError::IntervalOutOfRange(interval.months.to_string().into()))
871        .and_then(|months| add_timestamp_months(&dt, months))?;
872    let dt = dt
873        .checked_sub_signed(interval.duration_as_chrono())
874        .ok_or(EvalError::TimestampOutOfRange)?;
875    Ok(dt.try_into()?)
876}
877
878#[sqlfunc(
879    is_monotone = "(false, false)",
880    is_infix_op = true,
881    sqlname = "-",
882    propagates_nulls = true
883)]
884fn sub_time_interval(time: chrono::NaiveTime, interval: Interval) -> chrono::NaiveTime {
885    let (t, _) = time.overflowing_sub_signed(interval.duration_as_chrono());
886    t
887}
888
889#[sqlfunc(
890    is_monotone = "(true, true)",
891    is_infix_op = true,
892    sqlname = "*",
893    propagates_nulls = true
894)]
895fn mul_int16(a: i16, b: i16) -> Result<i16, EvalError> {
896    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
897}
898
899#[sqlfunc(
900    is_monotone = "(true, true)",
901    is_infix_op = true,
902    sqlname = "*",
903    propagates_nulls = true
904)]
905fn mul_int32(a: i32, b: i32) -> Result<i32, EvalError> {
906    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
907}
908
909#[sqlfunc(
910    is_monotone = "(true, true)",
911    is_infix_op = true,
912    sqlname = "*",
913    propagates_nulls = true
914)]
915fn mul_int64(a: i64, b: i64) -> Result<i64, EvalError> {
916    a.checked_mul(b).ok_or(EvalError::NumericFieldOverflow)
917}
918
919#[sqlfunc(
920    is_monotone = "(true, true)",
921    is_infix_op = true,
922    sqlname = "*",
923    propagates_nulls = true
924)]
925fn mul_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
926    a.checked_mul(b)
927        .ok_or_else(|| EvalError::UInt16OutOfRange(format!("{a} * {b}").into()))
928}
929
930#[sqlfunc(
931    is_monotone = "(true, true)",
932    is_infix_op = true,
933    sqlname = "*",
934    propagates_nulls = true
935)]
936fn mul_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
937    a.checked_mul(b)
938        .ok_or_else(|| EvalError::UInt32OutOfRange(format!("{a} * {b}").into()))
939}
940
941#[sqlfunc(
942    is_monotone = "(true, true)",
943    is_infix_op = true,
944    sqlname = "*",
945    propagates_nulls = true
946)]
947fn mul_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
948    a.checked_mul(b)
949        .ok_or_else(|| EvalError::UInt64OutOfRange(format!("{a} * {b}").into()))
950}
951
952#[sqlfunc(
953    is_monotone = (true, true),
954    is_infix_op = true,
955    sqlname = "*",
956    propagates_nulls = true
957)]
958fn mul_float32(a: f32, b: f32) -> Result<f32, EvalError> {
959    let product = a * b;
960    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
961        Err(EvalError::FloatOverflow)
962    } else if product == 0.0f32 && a != 0.0f32 && b != 0.0f32 {
963        Err(EvalError::FloatUnderflow)
964    } else {
965        Ok(product)
966    }
967}
968
969#[sqlfunc(
970    is_monotone = "(true, true)",
971    is_infix_op = true,
972    sqlname = "*",
973    propagates_nulls = true
974)]
975fn mul_float64(a: f64, b: f64) -> Result<f64, EvalError> {
976    let product = a * b;
977    if product.is_infinite() && !a.is_infinite() && !b.is_infinite() {
978        Err(EvalError::FloatOverflow)
979    } else if product == 0.0f64 && a != 0.0f64 && b != 0.0f64 {
980        Err(EvalError::FloatUnderflow)
981    } else {
982        Ok(product)
983    }
984}
985
986#[sqlfunc(
987    is_monotone = "(true, true)",
988    is_infix_op = true,
989    sqlname = "*",
990    propagates_nulls = true
991)]
992fn mul_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
993    let mut cx = numeric::cx_datum();
994    cx.mul(&mut a, &b);
995    let cx_status = cx.status();
996    if cx_status.overflow() {
997        Err(EvalError::FloatOverflow)
998    } else if cx_status.subnormal() {
999        Err(EvalError::FloatUnderflow)
1000    } else {
1001        numeric::munge_numeric(&mut a).unwrap();
1002        Ok(a)
1003    }
1004}
1005
1006#[sqlfunc(
1007    is_monotone = "(false, false)",
1008    is_infix_op = true,
1009    sqlname = "*",
1010    propagates_nulls = true
1011)]
1012fn mul_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1013    a.checked_mul(b)
1014        .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} * {b}").into()))
1015}
1016
1017#[sqlfunc(
1018    is_monotone = "(true, false)",
1019    is_infix_op = true,
1020    sqlname = "/",
1021    propagates_nulls = true
1022)]
1023fn div_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1024    if b == 0 {
1025        Err(EvalError::DivisionByZero)
1026    } else {
1027        a.checked_div(b)
1028            .ok_or_else(|| EvalError::Int16OutOfRange(format!("{a} / {b}").into()))
1029    }
1030}
1031
1032#[sqlfunc(
1033    is_monotone = "(true, false)",
1034    is_infix_op = true,
1035    sqlname = "/",
1036    propagates_nulls = true
1037)]
1038fn div_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1039    if b == 0 {
1040        Err(EvalError::DivisionByZero)
1041    } else {
1042        a.checked_div(b)
1043            .ok_or_else(|| EvalError::Int32OutOfRange(format!("{a} / {b}").into()))
1044    }
1045}
1046
1047#[sqlfunc(
1048    is_monotone = "(true, false)",
1049    is_infix_op = true,
1050    sqlname = "/",
1051    propagates_nulls = true
1052)]
1053fn div_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1054    if b == 0 {
1055        Err(EvalError::DivisionByZero)
1056    } else {
1057        a.checked_div(b)
1058            .ok_or_else(|| EvalError::Int64OutOfRange(format!("{a} / {b}").into()))
1059    }
1060}
1061
1062#[sqlfunc(
1063    is_monotone = "(true, false)",
1064    is_infix_op = true,
1065    sqlname = "/",
1066    propagates_nulls = true
1067)]
1068fn div_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1069    if b == 0 {
1070        Err(EvalError::DivisionByZero)
1071    } else {
1072        Ok(a / b)
1073    }
1074}
1075
1076#[sqlfunc(
1077    is_monotone = "(true, false)",
1078    is_infix_op = true,
1079    sqlname = "/",
1080    propagates_nulls = true
1081)]
1082fn div_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1083    if b == 0 {
1084        Err(EvalError::DivisionByZero)
1085    } else {
1086        Ok(a / b)
1087    }
1088}
1089
1090#[sqlfunc(
1091    is_monotone = "(true, false)",
1092    is_infix_op = true,
1093    sqlname = "/",
1094    propagates_nulls = true
1095)]
1096fn div_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1097    if b == 0 {
1098        Err(EvalError::DivisionByZero)
1099    } else {
1100        Ok(a / b)
1101    }
1102}
1103
1104#[sqlfunc(
1105    is_monotone = "(true, false)",
1106    is_infix_op = true,
1107    sqlname = "/",
1108    propagates_nulls = true
1109)]
1110fn div_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1111    if b == 0.0f32 && !a.is_nan() {
1112        Err(EvalError::DivisionByZero)
1113    } else {
1114        let quotient = a / b;
1115        if quotient.is_infinite() && !a.is_infinite() {
1116            Err(EvalError::FloatOverflow)
1117        } else if quotient == 0.0f32 && a != 0.0f32 && !b.is_infinite() {
1118            Err(EvalError::FloatUnderflow)
1119        } else {
1120            Ok(quotient)
1121        }
1122    }
1123}
1124
1125#[sqlfunc(
1126    is_monotone = "(true, false)",
1127    is_infix_op = true,
1128    sqlname = "/",
1129    propagates_nulls = true
1130)]
1131fn div_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1132    if b == 0.0f64 && !a.is_nan() {
1133        Err(EvalError::DivisionByZero)
1134    } else {
1135        let quotient = a / b;
1136        if quotient.is_infinite() && !a.is_infinite() {
1137            Err(EvalError::FloatOverflow)
1138        } else if quotient == 0.0f64 && a != 0.0f64 && !b.is_infinite() {
1139            Err(EvalError::FloatUnderflow)
1140        } else {
1141            Ok(quotient)
1142        }
1143    }
1144}
1145
1146#[sqlfunc(
1147    is_monotone = "(true, false)",
1148    is_infix_op = true,
1149    sqlname = "/",
1150    propagates_nulls = true
1151)]
1152fn div_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1153    let mut cx = numeric::cx_datum();
1154
1155    cx.div(&mut a, &b);
1156    let cx_status = cx.status();
1157
1158    // checking the status for division by zero errors is insufficient because
1159    // the underlying library treats 0/0 as undefined and not division by zero.
1160    if b.is_zero() {
1161        Err(EvalError::DivisionByZero)
1162    } else if cx_status.overflow() {
1163        Err(EvalError::FloatOverflow)
1164    } else if cx_status.subnormal() {
1165        Err(EvalError::FloatUnderflow)
1166    } else {
1167        numeric::munge_numeric(&mut a).unwrap();
1168        Ok(a)
1169    }
1170}
1171
1172#[sqlfunc(
1173    is_monotone = "(false, false)",
1174    is_infix_op = true,
1175    sqlname = "/",
1176    propagates_nulls = true
1177)]
1178fn div_interval(a: Interval, b: f64) -> Result<Interval, EvalError> {
1179    if b == 0.0 {
1180        Err(EvalError::DivisionByZero)
1181    } else {
1182        a.checked_div(b)
1183            .ok_or_else(|| EvalError::IntervalOutOfRange(format!("{a} / {b}").into()))
1184    }
1185}
1186
1187#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1188fn mod_int16(a: i16, b: i16) -> Result<i16, EvalError> {
1189    if b == 0 {
1190        Err(EvalError::DivisionByZero)
1191    } else {
1192        Ok(a.checked_rem(b).unwrap_or(0))
1193    }
1194}
1195
1196#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1197fn mod_int32(a: i32, b: i32) -> Result<i32, EvalError> {
1198    if b == 0 {
1199        Err(EvalError::DivisionByZero)
1200    } else {
1201        Ok(a.checked_rem(b).unwrap_or(0))
1202    }
1203}
1204
1205#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1206fn mod_int64(a: i64, b: i64) -> Result<i64, EvalError> {
1207    if b == 0 {
1208        Err(EvalError::DivisionByZero)
1209    } else {
1210        Ok(a.checked_rem(b).unwrap_or(0))
1211    }
1212}
1213
1214#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1215fn mod_uint16(a: u16, b: u16) -> Result<u16, EvalError> {
1216    if b == 0 {
1217        Err(EvalError::DivisionByZero)
1218    } else {
1219        Ok(a % b)
1220    }
1221}
1222
1223#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1224fn mod_uint32(a: u32, b: u32) -> Result<u32, EvalError> {
1225    if b == 0 {
1226        Err(EvalError::DivisionByZero)
1227    } else {
1228        Ok(a % b)
1229    }
1230}
1231
1232#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1233fn mod_uint64(a: u64, b: u64) -> Result<u64, EvalError> {
1234    if b == 0 {
1235        Err(EvalError::DivisionByZero)
1236    } else {
1237        Ok(a % b)
1238    }
1239}
1240
1241#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1242fn mod_float32(a: f32, b: f32) -> Result<f32, EvalError> {
1243    if b == 0.0 {
1244        Err(EvalError::DivisionByZero)
1245    } else {
1246        Ok(a % b)
1247    }
1248}
1249
1250#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1251fn mod_float64(a: f64, b: f64) -> Result<f64, EvalError> {
1252    if b == 0.0 {
1253        Err(EvalError::DivisionByZero)
1254    } else {
1255        Ok(a % b)
1256    }
1257}
1258
1259#[sqlfunc(is_infix_op = true, sqlname = "%", propagates_nulls = true)]
1260fn mod_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1261    if b.is_zero() {
1262        return Err(EvalError::DivisionByZero);
1263    }
1264    let mut cx = numeric::cx_datum();
1265    // Postgres does _not_ use IEEE 754-style remainder
1266    cx.rem(&mut a, &b);
1267    numeric::munge_numeric(&mut a).unwrap();
1268    Ok(a)
1269}
1270
1271fn neg_interval_inner(a: Interval) -> Result<Interval, EvalError> {
1272    a.checked_neg()
1273        .ok_or_else(|| EvalError::IntervalOutOfRange(a.to_string().into()))
1274}
1275
1276fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
1277    if val.is_negative() {
1278        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
1279    }
1280    if val.is_zero() {
1281        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
1282    }
1283    Ok(())
1284}
1285
1286#[sqlfunc(sqlname = "log", propagates_nulls = true)]
1287fn log_base_numeric(mut a: Numeric, mut b: Numeric) -> Result<Numeric, EvalError> {
1288    log_guard_numeric(&a, "log")?;
1289    log_guard_numeric(&b, "log")?;
1290    let mut cx = numeric::cx_datum();
1291    cx.ln(&mut a);
1292    cx.ln(&mut b);
1293    cx.div(&mut b, &a);
1294    if a.is_zero() {
1295        Err(EvalError::DivisionByZero)
1296    } else {
1297        // This division can result in slightly wrong answers due to the
1298        // limitation of dividing irrational numbers. To correct that, see if
1299        // rounding off the value from its `numeric::NUMERIC_DATUM_MAX_PRECISION
1300        // - 1`th position results in an integral value.
1301        cx.set_precision(usize::from(numeric::NUMERIC_DATUM_MAX_PRECISION - 1))
1302            .expect("reducing precision below max always succeeds");
1303        let mut integral_check = b.clone();
1304
1305        // `reduce` rounds to the context's final digit when the number of
1306        // digits in its argument exceeds its precision. We've contrived that to
1307        // happen by shrinking the context's precision by 1.
1308        cx.reduce(&mut integral_check);
1309
1310        // Reduced integral values always have a non-negative exponent.
1311        let mut b = if integral_check.exponent() >= 0 {
1312            // We believe our result should have been an integral
1313            integral_check
1314        } else {
1315            b
1316        };
1317
1318        numeric::munge_numeric(&mut b).unwrap();
1319        Ok(b)
1320    }
1321}
1322
1323#[sqlfunc(propagates_nulls = true)]
1324fn power(a: f64, b: f64) -> Result<f64, EvalError> {
1325    if a == 0.0 && b.is_sign_negative() {
1326        return Err(EvalError::Undefined(
1327            "zero raised to a negative power".into(),
1328        ));
1329    }
1330    if a.is_sign_negative() && b.fract() != 0.0 {
1331        // Equivalent to PG error:
1332        // > a negative number raised to a non-integer power yields a complex result
1333        return Err(EvalError::ComplexOutOfRange("pow".into()));
1334    }
1335    let res = a.powf(b);
1336    if res.is_infinite() {
1337        return Err(EvalError::FloatOverflow);
1338    }
1339    if res == 0.0 && a != 0.0 {
1340        return Err(EvalError::FloatUnderflow);
1341    }
1342    Ok(res)
1343}
1344
1345#[sqlfunc(propagates_nulls = true)]
1346fn uuid_generate_v5(a: uuid::Uuid, b: &str) -> uuid::Uuid {
1347    uuid::Uuid::new_v5(&a, b.as_bytes())
1348}
1349
1350#[sqlfunc(output_type = "Numeric", propagates_nulls = true)]
1351fn power_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
1352    if a.is_zero() {
1353        if b.is_zero() {
1354            return Ok(Numeric::from(1));
1355        }
1356        if b.is_negative() {
1357            return Err(EvalError::Undefined(
1358                "zero raised to a negative power".into(),
1359            ));
1360        }
1361    }
1362    if a.is_negative() && b.exponent() < 0 {
1363        // Equivalent to PG error:
1364        // > a negative number raised to a non-integer power yields a complex result
1365        return Err(EvalError::ComplexOutOfRange("pow".into()));
1366    }
1367    let mut cx = numeric::cx_datum();
1368    cx.pow(&mut a, &b);
1369    let cx_status = cx.status();
1370    if cx_status.overflow() || (cx_status.invalid_operation() && !b.is_negative()) {
1371        Err(EvalError::FloatOverflow)
1372    } else if cx_status.subnormal() || cx_status.invalid_operation() {
1373        Err(EvalError::FloatUnderflow)
1374    } else {
1375        numeric::munge_numeric(&mut a).unwrap();
1376        Ok(a)
1377    }
1378}
1379
1380#[sqlfunc(propagates_nulls = true)]
1381fn get_bit(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1382    let err = EvalError::IndexOutOfRange {
1383        provided: index,
1384        valid_end: i32::try_from(bytes.len().saturating_mul(8)).unwrap() - 1,
1385    };
1386
1387    let index = usize::try_from(index).map_err(|_| err.clone())?;
1388
1389    let byte_index = index / 8;
1390    let bit_index = index % 8;
1391
1392    let i = bytes
1393        .get(byte_index)
1394        .map(|b| (*b >> bit_index) & 1)
1395        .ok_or(err)?;
1396    assert!(i == 0 || i == 1);
1397    Ok(i32::from(i))
1398}
1399
1400#[sqlfunc(propagates_nulls = true)]
1401fn get_byte(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
1402    let err = EvalError::IndexOutOfRange {
1403        provided: index,
1404        valid_end: i32::try_from(bytes.len()).unwrap() - 1,
1405    };
1406    let i: &u8 = bytes
1407        .get(usize::try_from(index).map_err(|_| err.clone())?)
1408        .ok_or(err)?;
1409    Ok(i32::from(*i))
1410}
1411
1412#[sqlfunc(sqlname = "constant_time_compare_bytes", propagates_nulls = true)]
1413pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
1414    bool::from(a.ct_eq(b))
1415}
1416
1417#[sqlfunc(sqlname = "constant_time_compare_strings", propagates_nulls = true)]
1418pub fn constant_time_eq_string(a: &str, b: &str) -> bool {
1419    bool::from(a.as_bytes().ct_eq(b.as_bytes()))
1420}
1421
1422#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1423fn range_contains_i32<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1424    a.contains_elem(&b)
1425}
1426
1427#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1428fn range_contains_i64<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1429    a.contains_elem(&elem)
1430}
1431
1432#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1433fn range_contains_date<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1434    a.contains_elem(&elem)
1435}
1436
1437#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1438fn range_contains_numeric<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1439    a.contains_elem(&elem)
1440}
1441
1442#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1443fn range_contains_timestamp<'a>(
1444    a: Range<Datum<'a>>,
1445    elem: CheckedTimestamp<NaiveDateTime>,
1446) -> bool {
1447    a.contains_elem(&elem)
1448}
1449
1450#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1451fn range_contains_timestamp_tz<'a>(
1452    a: Range<Datum<'a>>,
1453    elem: CheckedTimestamp<DateTime<Utc>>,
1454) -> bool {
1455    a.contains_elem(&elem)
1456}
1457
1458#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1459fn range_contains_i32_rev<'a>(a: Range<Datum<'a>>, b: i32) -> bool {
1460    a.contains_elem(&b)
1461}
1462
1463#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1464fn range_contains_i64_rev<'a>(a: Range<Datum<'a>>, elem: i64) -> bool {
1465    a.contains_elem(&elem)
1466}
1467
1468#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1469fn range_contains_date_rev<'a>(a: Range<Datum<'a>>, elem: Date) -> bool {
1470    a.contains_elem(&elem)
1471}
1472
1473#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1474fn range_contains_numeric_rev<'a>(a: Range<Datum<'a>>, elem: OrderedDecimal<Numeric>) -> bool {
1475    a.contains_elem(&elem)
1476}
1477
1478#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1479fn range_contains_timestamp_rev<'a>(
1480    a: Range<Datum<'a>>,
1481    elem: CheckedTimestamp<NaiveDateTime>,
1482) -> bool {
1483    a.contains_elem(&elem)
1484}
1485
1486#[sqlfunc(is_infix_op = true, sqlname = "<@", propagates_nulls = true)]
1487fn range_contains_timestamp_tz_rev<'a>(
1488    a: Range<Datum<'a>>,
1489    elem: CheckedTimestamp<DateTime<Utc>>,
1490) -> bool {
1491    a.contains_elem(&elem)
1492}
1493
/// Defines a binary SQL function that applies a `Range` method to two range
/// datums.
///
/// Parameters:
/// 1. Suffix of the generated function, which is named `range_<suffix>`.
/// 2. Name of the `Range` method to call with both ranges.
/// 3. SQL name for the function.
///
/// The generated function returns `Datum::Null` if either input is null.
macro_rules! range_fn {
    ($fn:expr, $range_fn:expr, $sqlname:expr) => {
        paste::paste! {

            #[sqlfunc(
                output_type = "bool",
                is_infix_op = true,
                sqlname = $sqlname,
                propagates_nulls = true
            )]
            fn [< range_ $fn >]<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a>
            {
                if a.is_null() || b.is_null() {
                    Datum::Null
                } else {
                    let lhs = a.unwrap_range();
                    let rhs = b.unwrap_range();
                    Datum::from(Range::<Datum<'a>>::$range_fn(&lhs, &rhs))
                }
            }
        }
    };
}
1519
1520// RangeContainsRange is either @> or <@ depending on the order of the arguments.
1521// It doesn't influence the result, but it does influence the display string.
1522range_fn!(contains_range, contains_range, "@>");
1523range_fn!(contains_range_rev, contains_range, "<@");
1524range_fn!(overlaps, overlaps, "&&");
1525range_fn!(after, after, ">>");
1526range_fn!(before, before, "<<");
1527range_fn!(overleft, overleft, "&<");
1528range_fn!(overright, overright, "&>");
1529range_fn!(adjacent, adjacent, "-|-");
1530
1531#[sqlfunc(
1532    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
1533    is_infix_op = true,
1534    sqlname = "+",
1535    propagates_nulls = true,
1536    introduces_nulls = false
1537)]
1538fn range_union<'a>(
1539    l: Range<Datum<'a>>,
1540    r: Range<Datum<'a>>,
1541    temp_storage: &'a RowArena,
1542) -> Result<Datum<'a>, EvalError> {
1543    l.union(&r)?.into_result(temp_storage)
1544}
1545
1546#[sqlfunc(
1547    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
1548    is_infix_op = true,
1549    sqlname = "*",
1550    propagates_nulls = true,
1551    introduces_nulls = false
1552)]
1553fn range_intersection<'a>(
1554    l: Range<Datum<'a>>,
1555    r: Range<Datum<'a>>,
1556    temp_storage: &'a RowArena,
1557) -> Result<Datum<'a>, EvalError> {
1558    l.intersection(&r).into_result(temp_storage)
1559}
1560
1561#[sqlfunc(
1562    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
1563    is_infix_op = true,
1564    sqlname = "-",
1565    propagates_nulls = true,
1566    introduces_nulls = false
1567)]
1568fn range_difference<'a>(
1569    l: Range<Datum<'a>>,
1570    r: Range<Datum<'a>>,
1571    temp_storage: &'a RowArena,
1572) -> Result<Datum<'a>, EvalError> {
1573    l.difference(&r)?.into_result(temp_storage)
1574}
1575
1576#[sqlfunc(is_infix_op = true, sqlname = "=", negate = "Some(NotEq.into())")]
1577fn eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1578    // SQL equality demands that if either input is null, then the result should be null. However,
1579    // we don't need to handle this case here; it is handled when `BinaryFunc::eval` checks
1580    // `propagates_nulls`.
1581    a == b
1582}
1583
1584#[sqlfunc(is_infix_op = true, sqlname = "!=", negate = "Some(Eq.into())")]
1585fn not_eq<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1586    a != b
1587}
1588
1589#[sqlfunc(
1590    is_monotone = "(true, true)",
1591    is_infix_op = true,
1592    sqlname = "<",
1593    negate = "Some(Gte.into())"
1594)]
1595fn lt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1596    a < b
1597}
1598
1599#[sqlfunc(
1600    is_monotone = "(true, true)",
1601    is_infix_op = true,
1602    sqlname = "<=",
1603    negate = "Some(Gt.into())"
1604)]
1605fn lte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1606    a <= b
1607}
1608
1609#[sqlfunc(
1610    is_monotone = "(true, true)",
1611    is_infix_op = true,
1612    sqlname = ">",
1613    negate = "Some(Lte.into())"
1614)]
1615fn gt<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1616    a > b
1617}
1618
1619#[sqlfunc(
1620    is_monotone = "(true, true)",
1621    is_infix_op = true,
1622    sqlname = ">=",
1623    negate = "Some(Lt.into())"
1624)]
1625fn gte<'a>(a: ExcludeNull<Datum<'a>>, b: ExcludeNull<Datum<'a>>) -> bool {
1626    a >= b
1627}
1628
1629#[sqlfunc(sqlname = "tocharts", propagates_nulls = true)]
1630fn to_char_timestamp_format(ts: CheckedTimestamp<chrono::NaiveDateTime>, format: &str) -> String {
1631    let fmt = DateTimeFormat::compile(format);
1632    fmt.render(&*ts)
1633}
1634
1635#[sqlfunc(sqlname = "tochartstz", propagates_nulls = true)]
1636fn to_char_timestamp_tz_format(
1637    ts: CheckedTimestamp<chrono::DateTime<Utc>>,
1638    format: &str,
1639) -> String {
1640    let fmt = DateTimeFormat::compile(format);
1641    fmt.render(&*ts)
1642}
1643
1644#[sqlfunc(sqlname = "->", is_infix_op = true)]
1645fn jsonb_get_int64<'a>(a: JsonbRef<'a>, i: i64) -> Option<JsonbRef<'a>> {
1646    match a.into_datum() {
1647        Datum::List(list) => {
1648            let i = if i >= 0 {
1649                usize::cast_from(i.unsigned_abs())
1650            } else {
1651                // index backwards from the end
1652                let i = usize::cast_from(i.unsigned_abs());
1653                (list.iter().count()).wrapping_sub(i)
1654            };
1655            let v = list.iter().nth(i)?;
1656            // `v` should be valid jsonb because it came from a jsonb list, but we don't
1657            // panic on mismatch to avoid bringing down the whole system on corrupt data.
1658            // Instead, we'll return None.
1659            JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1660        }
1661        Datum::Map(_) => None,
1662        _ => {
1663            // I have no idea why postgres does this, but we're stuck with it
1664            (i == 0 || i == -1).then_some(a)
1665        }
1666    }
1667}
1668
1669#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1670fn jsonb_get_int64_stringify<'a>(
1671    a: JsonbRef<'a>,
1672    i: i64,
1673    temp_storage: &'a RowArena,
1674) -> Option<&'a str> {
1675    let json = jsonb_get_int64(a, i)?;
1676    jsonb_stringify(json.into_datum(), temp_storage)
1677}
1678
1679#[sqlfunc(sqlname = "->", is_infix_op = true)]
1680fn jsonb_get_string<'a>(a: JsonbRef<'a>, k: &str) -> Option<JsonbRef<'a>> {
1681    let dict = DatumMap::try_from_result(Ok::<_, ()>(a.into_datum())).ok()?;
1682    let v = dict.iter().find(|(k2, _v)| k == *k2).map(|(_k, v)| v)?;
1683    JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()
1684}
1685
1686#[sqlfunc(sqlname = "->>", is_infix_op = true)]
1687fn jsonb_get_string_stringify<'a>(
1688    a: JsonbRef<'a>,
1689    k: &str,
1690    temp_storage: &'a RowArena,
1691) -> Option<&'a str> {
1692    let v = jsonb_get_string(a, k)?;
1693    jsonb_stringify(v.into_datum(), temp_storage)
1694}
1695
1696#[sqlfunc(sqlname = "#>", is_infix_op = true)]
1697fn jsonb_get_path<'a>(mut json: JsonbRef<'a>, b: Array<'a>) -> Option<JsonbRef<'a>> {
1698    let path = b.elements();
1699    for key in path.iter() {
1700        let key = match key {
1701            Datum::String(s) => s,
1702            Datum::Null => return None,
1703            _ => unreachable!("keys in jsonb_get_path known to be strings"),
1704        };
1705        let v = match json.into_datum() {
1706            Datum::Map(map) => map.iter().find(|(k, _)| key == *k).map(|(_k, v)| v),
1707            Datum::List(list) => {
1708                let i = strconv::parse_int64(key).ok()?;
1709                let i = if i >= 0 {
1710                    usize::cast_from(i.unsigned_abs())
1711                } else {
1712                    // index backwards from the end
1713                    let i = usize::cast_from(i.unsigned_abs());
1714                    (list.iter().count()).wrapping_sub(i)
1715                };
1716                list.iter().nth(i)
1717            }
1718            _ => return None,
1719        }?;
1720        json = JsonbRef::try_from_result(Ok::<_, ()>(v)).ok()?;
1721    }
1722    Some(json)
1723}
1724
1725#[sqlfunc(sqlname = "#>>", is_infix_op = true)]
1726fn jsonb_get_path_stringify<'a>(
1727    a: JsonbRef<'a>,
1728    b: Array<'a>,
1729    temp_storage: &'a RowArena,
1730) -> Option<&'a str> {
1731    let json = jsonb_get_path(a, b)?;
1732    jsonb_stringify(json.into_datum(), temp_storage)
1733}
1734
1735#[sqlfunc(is_infix_op = true, sqlname = "?", propagates_nulls = true)]
1736fn jsonb_contains_string<'a>(a: Datum<'a>, k: &str) -> bool {
1737    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
1738    match a {
1739        Datum::List(list) => list.iter().any(|k2| Datum::from(k) == k2),
1740        Datum::Map(dict) => dict.iter().any(|(k2, _v)| k == k2),
1741        Datum::String(string) => string == k,
1742        _ => false,
1743    }
1744}
1745
1746#[sqlfunc(is_infix_op = true, sqlname = "?", propagates_nulls = true)]
1747// Map keys are always text.
1748fn map_contains_key<'a>(map: DatumMap<'a>, k: &str) -> bool {
1749    map.iter().any(|(k2, _v)| k == k2)
1750}
1751
1752#[sqlfunc(is_infix_op = true, sqlname = "?&")]
1753fn map_contains_all_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1754    keys.elements()
1755        .iter()
1756        .all(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1757}
1758
1759#[sqlfunc(is_infix_op = true, sqlname = "?|", propagates_nulls = true)]
1760fn map_contains_any_keys<'a>(map: DatumMap<'a>, keys: Array<'a>) -> bool {
1761    keys.elements()
1762        .iter()
1763        .any(|key| !key.is_null() && map.iter().any(|(k, _v)| k == key.unwrap_str()))
1764}
1765
1766#[sqlfunc(is_infix_op = true, sqlname = "@>", propagates_nulls = true)]
1767fn map_contains_map<'a>(map_a: DatumMap<'a>, b: DatumMap<'a>) -> bool {
1768    b.iter().all(|(b_key, b_val)| {
1769        map_a
1770            .iter()
1771            .any(|(a_key, a_val)| (a_key == b_key) && (a_val == b_val))
1772    })
1773}
1774
1775#[sqlfunc(
1776    output_type_expr = "input_types[0].scalar_type.unwrap_map_value_type().clone().nullable(true)",
1777    is_infix_op = true,
1778    sqlname = "->",
1779    propagates_nulls = true,
1780    introduces_nulls = true
1781)]
1782fn map_get_value<'a>(a: DatumMap<'a>, target_key: &str) -> Datum<'a> {
1783    match a.iter().find(|(key, _v)| target_key == *key) {
1784        Some((_k, v)) => v,
1785        None => Datum::Null,
1786    }
1787}
1788
1789#[sqlfunc(is_infix_op = true, sqlname = "@>")]
1790fn list_contains_list<'a>(a: ExcludeNull<DatumList<'a>>, b: ExcludeNull<DatumList<'a>>) -> bool {
1791    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
1792    if b.iter().contains(&Datum::Null) {
1793        false
1794    } else {
1795        b.iter()
1796            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
1797    }
1798}
1799
1800#[sqlfunc(is_infix_op = true, sqlname = "<@")]
1801fn list_contains_list_rev<'a>(
1802    a: ExcludeNull<DatumList<'a>>,
1803    b: ExcludeNull<DatumList<'a>>,
1804) -> bool {
1805    list_contains_list(b, a)
1806}
1807
1808// TODO(jamii) nested loops are possibly not the fastest way to do this
1809#[sqlfunc(is_infix_op = true, sqlname = "@>")]
1810fn jsonb_contains_jsonb<'a>(a: JsonbRef<'a>, b: JsonbRef<'a>) -> bool {
1811    // https://www.postgresql.org/docs/current/datatype-json.html#JSON-CONTAINMENT
1812    fn contains(a: Datum, b: Datum, at_top_level: bool) -> bool {
1813        match (a, b) {
1814            (Datum::JsonNull, Datum::JsonNull) => true,
1815            (Datum::False, Datum::False) => true,
1816            (Datum::True, Datum::True) => true,
1817            (Datum::Numeric(a), Datum::Numeric(b)) => a == b,
1818            (Datum::String(a), Datum::String(b)) => a == b,
1819            (Datum::List(a), Datum::List(b)) => b
1820                .iter()
1821                .all(|b_elem| a.iter().any(|a_elem| contains(a_elem, b_elem, false))),
1822            (Datum::Map(a), Datum::Map(b)) => b.iter().all(|(b_key, b_val)| {
1823                a.iter()
1824                    .any(|(a_key, a_val)| (a_key == b_key) && contains(a_val, b_val, false))
1825            }),
1826
1827            // fun special case
1828            (Datum::List(a), b) => {
1829                at_top_level && a.iter().any(|a_elem| contains(a_elem, b, false))
1830            }
1831
1832            _ => false,
1833        }
1834    }
1835    contains(a.into_datum(), b.into_datum(), true)
1836}
1837
1838#[sqlfunc(is_infix_op = true, sqlname = "||")]
1839fn jsonb_concat<'a>(
1840    a: JsonbRef<'a>,
1841    b: JsonbRef<'a>,
1842    temp_storage: &'a RowArena,
1843) -> Option<JsonbRef<'a>> {
1844    let res = match (a.into_datum(), b.into_datum()) {
1845        (Datum::Map(dict_a), Datum::Map(dict_b)) => {
1846            let mut pairs = dict_b.iter().chain(dict_a.iter()).collect::<Vec<_>>();
1847            // stable sort, so if keys collide dedup prefers dict_b
1848            pairs.sort_by(|(k1, _v1), (k2, _v2)| k1.cmp(k2));
1849            pairs.dedup_by(|(k1, _v1), (k2, _v2)| k1 == k2);
1850            temp_storage.make_datum(|packer| packer.push_dict(pairs))
1851        }
1852        (Datum::List(list_a), Datum::List(list_b)) => {
1853            let elems = list_a.iter().chain(list_b.iter());
1854            temp_storage.make_datum(|packer| packer.push_list(elems))
1855        }
1856        (Datum::List(list_a), b) => {
1857            let elems = list_a.iter().chain(Some(b));
1858            temp_storage.make_datum(|packer| packer.push_list(elems))
1859        }
1860        (a, Datum::List(list_b)) => {
1861            let elems = Some(a).into_iter().chain(list_b.iter());
1862            temp_storage.make_datum(|packer| packer.push_list(elems))
1863        }
1864        _ => return None,
1865    };
1866    Some(JsonbRef::from_datum(res))
1867}
1868
1869#[sqlfunc(
1870    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1871    is_infix_op = true,
1872    sqlname = "-",
1873    propagates_nulls = true,
1874    introduces_nulls = true
1875)]
1876fn jsonb_delete_int64<'a>(a: Datum<'a>, i: i64, temp_storage: &'a RowArena) -> Datum<'a> {
1877    match a {
1878        Datum::List(list) => {
1879            let i = if i >= 0 {
1880                usize::cast_from(i.unsigned_abs())
1881            } else {
1882                // index backwards from the end
1883                let i = usize::cast_from(i.unsigned_abs());
1884                (list.iter().count()).wrapping_sub(i)
1885            };
1886            let elems = list
1887                .iter()
1888                .enumerate()
1889                .filter(|(i2, _e)| i != *i2)
1890                .map(|(_, e)| e);
1891            temp_storage.make_datum(|packer| packer.push_list(elems))
1892        }
1893        _ => Datum::Null,
1894    }
1895}
1896
1897#[sqlfunc(
1898    output_type_expr = "SqlScalarType::Jsonb.nullable(true)",
1899    is_infix_op = true,
1900    sqlname = "-",
1901    propagates_nulls = true,
1902    introduces_nulls = true
1903)]
1904fn jsonb_delete_string<'a>(a: Datum<'a>, k: &str, temp_storage: &'a RowArena) -> Datum<'a> {
1905    match a {
1906        Datum::List(list) => {
1907            let elems = list.iter().filter(|e| Datum::from(k) != *e);
1908            temp_storage.make_datum(|packer| packer.push_list(elems))
1909        }
1910        Datum::Map(dict) => {
1911            let pairs = dict.iter().filter(|(k2, _v)| k != *k2);
1912            temp_storage.make_datum(|packer| packer.push_dict(pairs))
1913        }
1914        _ => Datum::Null,
1915    }
1916}
1917
1918#[sqlfunc(
1919    sqlname = "extractiv",
1920    propagates_nulls = true,
1921    introduces_nulls = false
1922)]
1923fn date_part_interval_numeric(units: &str, b: Interval) -> Result<Numeric, EvalError> {
1924    match units.parse() {
1925        Ok(units) => Ok(date_part_interval_inner::<Numeric>(units, b)?),
1926        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1927    }
1928}
1929
1930#[sqlfunc(
1931    sqlname = "date_partiv",
1932    propagates_nulls = true,
1933    introduces_nulls = false
1934)]
1935fn date_part_interval_f64(units: &str, b: Interval) -> Result<f64, EvalError> {
1936    match units.parse() {
1937        Ok(units) => Ok(date_part_interval_inner::<f64>(units, b)?),
1938        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1939    }
1940}
1941
1942#[sqlfunc(
1943    sqlname = "extractt",
1944    propagates_nulls = true,
1945    introduces_nulls = false
1946)]
1947fn date_part_time_numeric(units: &str, b: chrono::NaiveTime) -> Result<Numeric, EvalError> {
1948    match units.parse() {
1949        Ok(units) => Ok(date_part_time_inner::<Numeric>(units, b)?),
1950        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1951    }
1952}
1953
1954#[sqlfunc(
1955    sqlname = "date_partt",
1956    propagates_nulls = true,
1957    introduces_nulls = false
1958)]
1959fn date_part_time_f64(units: &str, b: chrono::NaiveTime) -> Result<f64, EvalError> {
1960    match units.parse() {
1961        Ok(units) => Ok(date_part_time_inner::<f64>(units, b)?),
1962        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1963    }
1964}
1965
1966#[sqlfunc(sqlname = "extractts", propagates_nulls = true)]
1967fn date_part_timestamp_timestamp_numeric(
1968    units: &str,
1969    ts: CheckedTimestamp<NaiveDateTime>,
1970) -> Result<Numeric, EvalError> {
1971    match units.parse() {
1972        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1973        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1974    }
1975}
1976
1977#[sqlfunc(sqlname = "extracttstz", propagates_nulls = true)]
1978fn date_part_timestamp_timestamp_tz_numeric(
1979    units: &str,
1980    ts: CheckedTimestamp<DateTime<Utc>>,
1981) -> Result<Numeric, EvalError> {
1982    match units.parse() {
1983        Ok(units) => Ok(date_part_timestamp_inner::<_, Numeric>(units, &*ts)?),
1984        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1985    }
1986}
1987
1988#[sqlfunc(sqlname = "date_partts", propagates_nulls = true)]
1989fn date_part_timestamp_timestamp_f64(
1990    units: &str,
1991    ts: CheckedTimestamp<NaiveDateTime>,
1992) -> Result<f64, EvalError> {
1993    match units.parse() {
1994        Ok(units) => date_part_timestamp_inner(units, &*ts),
1995        Err(_) => Err(EvalError::UnknownUnits(units.into())),
1996    }
1997}
1998
1999#[sqlfunc(sqlname = "date_parttstz", propagates_nulls = true)]
2000fn date_part_timestamp_timestamp_tz_f64(
2001    units: &str,
2002    ts: CheckedTimestamp<DateTime<Utc>>,
2003) -> Result<f64, EvalError> {
2004    match units.parse() {
2005        Ok(units) => date_part_timestamp_inner(units, &*ts),
2006        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2007    }
2008}
2009
2010#[sqlfunc(sqlname = "extractd", propagates_nulls = true)]
2011fn extract_date_units(units: &str, b: Date) -> Result<Numeric, EvalError> {
2012    match units.parse() {
2013        Ok(units) => Ok(extract_date_inner(units, b.into())?),
2014        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2015    }
2016}
2017
2018pub fn date_bin<T>(
2019    stride: Interval,
2020    source: CheckedTimestamp<T>,
2021    origin: CheckedTimestamp<T>,
2022) -> Result<CheckedTimestamp<T>, EvalError>
2023where
2024    T: TimestampLike,
2025{
2026    if stride.months != 0 {
2027        return Err(EvalError::DateBinOutOfRange(
2028            "timestamps cannot be binned into intervals containing months or years".into(),
2029        ));
2030    }
2031
2032    let stride_ns = match stride.duration_as_chrono().num_nanoseconds() {
2033        Some(ns) if ns <= 0 => Err(EvalError::DateBinOutOfRange(
2034            "stride must be greater than zero".into(),
2035        )),
2036        Some(ns) => Ok(ns),
2037        None => Err(EvalError::DateBinOutOfRange(
2038            format!("stride cannot exceed {}/{} nanoseconds", i64::MAX, i64::MIN,).into(),
2039        )),
2040    }?;
2041
2042    // Make sure the returned timestamp is at the start of the bin, even if the
2043    // origin is in the future. We do this here because `T` is not `Copy` and
2044    // gets moved by its subtraction operation.
2045    let sub_stride = origin > source;
2046
2047    let tm_diff = (source - origin.clone()).num_nanoseconds().ok_or_else(|| {
2048        EvalError::DateBinOutOfRange(
2049            "source and origin must not differ more than 2^63 nanoseconds".into(),
2050        )
2051    })?;
2052
2053    let mut tm_delta = tm_diff - tm_diff % stride_ns;
2054
2055    if sub_stride {
2056        tm_delta -= stride_ns;
2057    }
2058
2059    let res = origin
2060        .checked_add_signed(Duration::nanoseconds(tm_delta))
2061        .ok_or(EvalError::TimestampOutOfRange)?;
2062    Ok(CheckedTimestamp::from_timestamplike(res)?)
2063}
2064
2065#[sqlfunc(is_monotone = "(true, true)", sqlname = "bin_unix_epoch_timestamp")]
2066fn date_bin_timestamp(
2067    stride: Interval,
2068    source: CheckedTimestamp<NaiveDateTime>,
2069) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2070    let origin =
2071        CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap().naive_utc())
2072            .expect("must fit");
2073    date_bin(stride, source, origin)
2074}
2075
2076#[sqlfunc(is_monotone = "(true, true)", sqlname = "bin_unix_epoch_timestamptz")]
2077fn date_bin_timestamp_tz(
2078    stride: Interval,
2079    source: CheckedTimestamp<DateTime<Utc>>,
2080) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2081    let origin = CheckedTimestamp::from_timestamplike(DateTime::from_timestamp(0, 0).unwrap())
2082        .expect("must fit");
2083    date_bin(stride, source, origin)
2084}
2085
2086#[sqlfunc(sqlname = "date_truncts", propagates_nulls = true)]
2087fn date_trunc_units_timestamp(
2088    units: &str,
2089    ts: CheckedTimestamp<NaiveDateTime>,
2090) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2091    match units.parse() {
2092        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2093        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2094    }
2095}
2096
2097#[sqlfunc(sqlname = "date_trunctstz", propagates_nulls = true)]
2098fn date_trunc_units_timestamp_tz(
2099    units: &str,
2100    ts: CheckedTimestamp<DateTime<Utc>>,
2101) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2102    match units.parse() {
2103        Ok(units) => Ok(date_trunc_inner(units, &*ts)?.try_into()?),
2104        Err(_) => Err(EvalError::UnknownUnits(units.into())),
2105    }
2106}
2107
2108#[sqlfunc(sqlname = "date_trunciv", propagates_nulls = true)]
2109fn date_trunc_interval(units: &str, mut interval: Interval) -> Result<Interval, EvalError> {
2110    let dtf = units
2111        .parse()
2112        .map_err(|_| EvalError::UnknownUnits(units.into()))?;
2113
2114    interval
2115        .truncate_low_fields(dtf, Some(0), RoundBehavior::Truncate)
2116        .expect(
2117            "truncate_low_fields should not fail with max_precision 0 and RoundBehavior::Truncate",
2118        );
2119    Ok(interval)
2120}
2121
2122/// Parses a named timezone like `EST` or `America/New_York`, or a fixed-offset timezone like `-05:00`.
2123///
2124/// The interpretation of fixed offsets depend on whether the POSIX or ISO 8601 standard is being
2125/// used.
2126pub(crate) fn parse_timezone(tz: &str, spec: TimezoneSpec) -> Result<Timezone, EvalError> {
2127    Timezone::parse(tz, spec).map_err(|_| EvalError::InvalidTimezone(tz.into()))
2128}
2129
2130/// Converts the time datum `b`, which is assumed to be in UTC, to the timezone that the interval datum `a` is assumed
2131/// to represent. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2132/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2133#[sqlfunc(sqlname = "timezoneit")]
2134fn timezone_interval_time_binary(
2135    interval: Interval,
2136    time: chrono::NaiveTime,
2137) -> Result<chrono::NaiveTime, EvalError> {
2138    if interval.months != 0 {
2139        Err(EvalError::InvalidTimezoneInterval)
2140    } else {
2141        Ok(time.overflowing_add_signed(interval.duration_as_chrono()).0)
2142    }
2143}
2144
2145/// Converts the timestamp datum `b`, which is assumed to be in the time of the timezone datum `a` to a timestamptz
2146/// in UTC. The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2147/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2148#[sqlfunc(sqlname = "timezoneits")]
2149fn timezone_interval_timestamp_binary(
2150    interval: Interval,
2151    ts: CheckedTimestamp<NaiveDateTime>,
2152) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
2153    if interval.months != 0 {
2154        Err(EvalError::InvalidTimezoneInterval)
2155    } else {
2156        match ts.checked_sub_signed(interval.duration_as_chrono()) {
2157            Some(sub) => Ok(DateTime::from_naive_utc_and_offset(sub, Utc).try_into()?),
2158            None => Err(EvalError::TimestampOutOfRange),
2159        }
2160    }
2161}
2162
2163/// Converts the UTC timestamptz datum `b`, to the local timestamp of the timezone datum `a`.
2164/// The interval is not allowed to hold months, but there are no limits on the amount of seconds.
2165/// The interval acts like a `chrono::FixedOffset`, without the `-86,400 < x < 86,400` limitation.
2166#[sqlfunc(sqlname = "timezoneitstz")]
2167fn timezone_interval_timestamp_tz_binary(
2168    interval: Interval,
2169    tstz: CheckedTimestamp<DateTime<Utc>>,
2170) -> Result<CheckedTimestamp<NaiveDateTime>, EvalError> {
2171    if interval.months != 0 {
2172        return Err(EvalError::InvalidTimezoneInterval);
2173    }
2174    match tstz
2175        .naive_utc()
2176        .checked_add_signed(interval.duration_as_chrono())
2177    {
2178        Some(dt) => Ok(dt.try_into()?),
2179        None => Err(EvalError::TimestampOutOfRange),
2180    }
2181}
2182
2183#[sqlfunc(
2184    output_type_expr = r#"SqlScalarType::Record {
2185                fields: [
2186                    ("abbrev".into(), SqlScalarType::String.nullable(false)),
2187                    ("base_utc_offset".into(), SqlScalarType::Interval.nullable(false)),
2188                    ("dst_offset".into(), SqlScalarType::Interval.nullable(false)),
2189                ].into(),
2190                custom_id: None,
2191            }.nullable(true)"#,
2192    propagates_nulls = true,
2193    introduces_nulls = false
2194)]
2195fn timezone_offset<'a>(
2196    tz_str: &str,
2197    b: CheckedTimestamp<chrono::DateTime<Utc>>,
2198    temp_storage: &'a RowArena,
2199) -> Result<Datum<'a>, EvalError> {
2200    let tz = match Tz::from_str_insensitive(tz_str) {
2201        Ok(tz) => tz,
2202        Err(_) => return Err(EvalError::InvalidIanaTimezoneId(tz_str.into())),
2203    };
2204    let offset = tz.offset_from_utc_datetime(&b.naive_utc());
2205    Ok(temp_storage.make_datum(|packer| {
2206        packer.push_list_with(|packer| {
2207            packer.push(Datum::from(offset.abbreviation()));
2208            packer.push(Datum::from(offset.base_utc_offset()));
2209            packer.push(Datum::from(offset.dst_offset()));
2210        });
2211    }))
2212}
2213
2214/// Determines if an mz_aclitem contains one of the specified privileges. This will return true if
2215/// any of the listed privileges are contained in the mz_aclitem.
2216#[sqlfunc(
2217    sqlname = "mz_aclitem_contains_privilege",
2218    output_type = "bool",
2219    propagates_nulls = true
2220)]
2221fn mz_acl_item_contains_privilege(
2222    mz_acl_item: MzAclItem,
2223    privileges: &str,
2224) -> Result<bool, EvalError> {
2225    let acl_mode = AclMode::parse_multiple_privileges(privileges)
2226        .map_err(|e: anyhow::Error| EvalError::InvalidPrivileges(e.to_string().into()))?;
2227    let contains = !mz_acl_item.acl_mode.intersection(acl_mode).is_empty();
2228    Ok(contains)
2229}
2230
/// Splits a possibly qualified identifier such as `a."B".c` into its
/// dot-separated parts.
///
/// Double-quoted parts are returned verbatim (borrowed from `ident`, without
/// the quotes); unquoted parts are converted to ASCII lowercase. ASCII
/// whitespace is skipped at the start of the input, after each part and after
/// each `.`.
///
/// When `strict` is true, any character other than `.` following a part is an
/// error; when it is false, parsing stops there and the parts collected so far
/// are returned.
///
/// # Errors
///
/// Returns `EvalError::InvalidIdentifier` if a quoted part is not terminated,
/// if a part is missing (at the start, before a `.` or after a `.`), or, in
/// strict mode, if a part is followed by unexpected input.
#[sqlfunc]
// transliterated from postgres/src/backend/utils/adt/misc.c
fn parse_ident<'a>(ident: &'a str, strict: bool) -> Result<ArrayRustType<Cow<'a, str>>, EvalError> {
    // Characters that may start an unquoted identifier: ASCII letters, `_`,
    // and every character at or above U+0080.
    fn is_ident_start(c: char) -> bool {
        matches!(c, 'A'..='Z' | 'a'..='z' | '_' | '\u{80}'..=char::MAX)
    }

    // Characters that may continue an unquoted identifier: everything that may
    // start one, plus ASCII digits and `$`.
    fn is_ident_cont(c: char) -> bool {
        matches!(c, '0'..='9' | '$') || is_ident_start(c)
    }

    let mut elems = vec![];
    let buf = &mut LexBuf::new(ident);

    // Set once a `.` has been consumed; only used to pick the error detail
    // when the following part is missing.
    let mut after_dot = false;

    buf.take_while(|ch| ch.is_ascii_whitespace());

    loop {
        let mut missing_ident = true;

        let c = buf.next();

        if c == Some('"') {
            // Quoted part: everything up to the next double quote, verbatim.
            let s = buf.take_while(|ch| !matches!(ch, '"'));

            if buf.next() != Some('"') {
                return Err(EvalError::InvalidIdentifier {
                    ident: ident.into(),
                    detail: Some("String has unclosed double quotes.".into()),
                });
            }
            elems.push(Cow::Borrowed(s));
            missing_ident = false;
        } else if c.map(is_ident_start).unwrap_or(false) {
            // Unquoted part: step back so the first character is included in
            // the run collected below, then fold it to ASCII lowercase.
            buf.prev();
            let s = buf.take_while(is_ident_cont);
            elems.push(Cow::Owned(s.to_ascii_lowercase()));
            missing_ident = false;
        }

        if missing_ident {
            // No part was found at this position; choose the error detail
            // based on what was seen instead.
            if c == Some('.') {
                return Err(EvalError::InvalidIdentifier {
                    ident: ident.into(),
                    detail: Some("No valid identifier before \".\".".into()),
                });
            } else if after_dot {
                return Err(EvalError::InvalidIdentifier {
                    ident: ident.into(),
                    detail: Some("No valid identifier after \".\".".into()),
                });
            } else {
                return Err(EvalError::InvalidIdentifier {
                    ident: ident.into(),
                    detail: None,
                });
            }
        }

        buf.take_while(|ch| ch.is_ascii_whitespace());

        match buf.next() {
            Some('.') => {
                // Another part must follow.
                after_dot = true;

                buf.take_while(|ch| ch.is_ascii_whitespace());
            }
            Some(_) if strict => {
                // Trailing input is rejected in strict mode.
                return Err(EvalError::InvalidIdentifier {
                    ident: ident.into(),
                    detail: None,
                });
            }
            // End of input, or trailing input in non-strict mode: done.
            _ => break,
        }
    }

    Ok(elems.into())
}
2311
2312fn regexp_split_to_array_re<'a>(
2313    text: &str,
2314    regexp: &Regex,
2315    temp_storage: &'a RowArena,
2316) -> Result<Datum<'a>, EvalError> {
2317    let found = mz_regexp::regexp_split_to_array(text, regexp);
2318    let mut row = Row::default();
2319    let mut packer = row.packer();
2320    packer.try_push_array(
2321        &[ArrayDimension {
2322            lower_bound: 1,
2323            length: found.len(),
2324        }],
2325        found.into_iter().map(Datum::String),
2326    )?;
2327    Ok(temp_storage.push_unary_row(row))
2328}
2329
2330#[sqlfunc(propagates_nulls = true)]
2331fn pretty_sql<'a>(sql: &str, width: i32, temp_storage: &'a RowArena) -> Result<&'a str, EvalError> {
2332    let width =
2333        usize::try_from(width).map_err(|_| EvalError::PrettyError("invalid width".into()))?;
2334    let pretty = pretty_str(
2335        sql,
2336        PrettyConfig {
2337            width,
2338            format_mode: FormatMode::Simple,
2339        },
2340    )
2341    .map_err(|e| EvalError::PrettyError(e.to_string().into()))?;
2342    let pretty = temp_storage.push_string(pretty);
2343    Ok(pretty)
2344}
2345
2346#[sqlfunc(propagates_nulls = true)]
2347fn starts_with(a: &str, b: &str) -> bool {
2348    a.starts_with(b)
2349}
2350
2351#[sqlfunc(
2352    sqlname = "||",
2353    is_infix_op = true,
2354    propagates_nulls = true,
2355    // Text concatenation is monotonic in its second argument, because if I change the
2356    // second argument but don't change the first argument, then we won't find a difference
2357    // in that part of the concatenation result that came from the first argument, so we'll
2358    // find the difference that comes from changing the second argument.
2359    // (It's not monotonic in its first argument, because e.g.,
2360    // 'A' < 'AA' but 'AZ' > 'AAZ'.)
2361    is_monotone = (false, true),
2362)]
2363fn text_concat_binary(a: &str, b: &str) -> Result<String, EvalError> {
2364    if a.len() + b.len() > MAX_STRING_FUNC_RESULT_BYTES {
2365        return Err(EvalError::LengthTooLarge);
2366    }
2367    let mut buf = String::with_capacity(a.len() + b.len());
2368    buf.push_str(a);
2369    buf.push_str(b);
2370    Ok(buf)
2371}
2372
2373#[sqlfunc(propagates_nulls = true, introduces_nulls = false)]
2374fn like_escape<'a>(
2375    pattern: &str,
2376    b: &str,
2377    temp_storage: &'a RowArena,
2378) -> Result<&'a str, EvalError> {
2379    let escape = like_pattern::EscapeBehavior::from_str(b)?;
2380    let normalized = like_pattern::normalize_pattern(pattern, escape)?;
2381    Ok(temp_storage.push_string(normalized))
2382}
2383
2384#[sqlfunc(is_infix_op = true, sqlname = "like")]
2385fn is_like_match_case_sensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2386    like_pattern::compile(pattern, false).map(|needle| needle.is_match(haystack))
2387}
2388
2389#[sqlfunc(is_infix_op = true, sqlname = "ilike")]
2390fn is_like_match_case_insensitive(haystack: &str, pattern: &str) -> Result<bool, EvalError> {
2391    like_pattern::compile(pattern, true).map(|needle| needle.is_match(haystack))
2392}
2393
2394#[sqlfunc(is_infix_op = true, sqlname = "~")]
2395fn is_regexp_match_case_sensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2396    let regex = build_regex(needle, "")?;
2397    Ok(regex.is_match(haystack))
2398}
2399
2400#[sqlfunc(is_infix_op = true, sqlname = "~*")]
2401fn is_regexp_match_case_insensitive(haystack: &str, needle: &str) -> Result<bool, EvalError> {
2402    let regex = build_regex(needle, "i")?;
2403    Ok(regex.is_match(haystack))
2404}
2405
2406fn regexp_match_static<'a>(
2407    haystack: Datum<'a>,
2408    temp_storage: &'a RowArena,
2409    needle: &regex::Regex,
2410) -> Result<Datum<'a>, EvalError> {
2411    let mut row = Row::default();
2412    let mut packer = row.packer();
2413    if needle.captures_len() > 1 {
2414        // The regex contains capture groups, so return an array containing the
2415        // matched text in each capture group, unless the entire match fails.
2416        // Individual capture groups may also be null if that group did not
2417        // participate in the match.
2418        match needle.captures(haystack.unwrap_str()) {
2419            None => packer.push(Datum::Null),
2420            Some(captures) => packer.try_push_array(
2421                &[ArrayDimension {
2422                    lower_bound: 1,
2423                    length: captures.len() - 1,
2424                }],
2425                // Skip the 0th capture group, which is the whole match.
2426                captures.iter().skip(1).map(|mtch| match mtch {
2427                    None => Datum::Null,
2428                    Some(mtch) => Datum::String(mtch.as_str()),
2429                }),
2430            )?,
2431        }
2432    } else {
2433        // The regex contains no capture groups, so return a one-element array
2434        // containing the match, or null if there is no match.
2435        match needle.find(haystack.unwrap_str()) {
2436            None => packer.push(Datum::Null),
2437            Some(mtch) => packer.try_push_array(
2438                &[ArrayDimension {
2439                    lower_bound: 1,
2440                    length: 1,
2441                }],
2442                iter::once(Datum::String(mtch.as_str())),
2443            )?,
2444        };
2445    };
2446    Ok(temp_storage.push_unary_row(row))
2447}
2448
/// Splits the `regexp_replace` flags into a replacement limit, suitable for
/// `Regex::replacen`, and the remaining flags.
///
/// If `flags` contains `g` (replace all matches), the limit is `0` and every
/// `g` is removed from the returned flags. Otherwise the limit is `1` (replace
/// only the first match) and `flags` is returned unchanged, borrowed, so that
/// this common path does not allocate.
pub(crate) fn regexp_replace_parse_flags(flags: &str) -> (usize, Cow<'_, str>) {
    if !flags.contains('g') {
        return (1, Cow::Borrowed(flags));
    }
    let remaining: String = flags.chars().filter(|&flag| flag != 'g').collect();
    (0, Cow::Owned(remaining))
}
2462
2463pub fn build_regex(needle: &str, flags: &str) -> Result<Regex, EvalError> {
2464    let mut case_insensitive = false;
2465    // Note: Postgres accepts it when both flags are present, taking the last one. We do the same.
2466    for f in flags.chars() {
2467        match f {
2468            'i' => {
2469                case_insensitive = true;
2470            }
2471            'c' => {
2472                case_insensitive = false;
2473            }
2474            _ => return Err(EvalError::InvalidRegexFlag(f)),
2475        }
2476    }
2477    Ok(Regex::new(needle, case_insensitive)?)
2478}
2479
2480#[sqlfunc(sqlname = "repeat")]
2481fn repeat_string(string: &str, count: i32) -> Result<String, EvalError> {
2482    let len = usize::try_from(count).unwrap_or(0);
2483    if (len * string.len()) > MAX_STRING_FUNC_RESULT_BYTES {
2484        return Err(EvalError::LengthTooLarge);
2485    }
2486    Ok(string.repeat(len))
2487}
2488
2489/// Constructs a new zero or one dimensional array out of an arbitrary number of
2490/// scalars.
2491///
2492/// If `datums` is empty, constructs a zero-dimensional array. Otherwise,
2493/// constructs a one dimensional array whose lower bound is one and whose length
2494/// is equal to `datums.len()`.
2495fn array_create_scalar<'a>(
2496    datums: &[Datum<'a>],
2497    temp_storage: &'a RowArena,
2498) -> Result<Datum<'a>, EvalError> {
2499    let mut dims = &[ArrayDimension {
2500        lower_bound: 1,
2501        length: datums.len(),
2502    }][..];
2503    if datums.is_empty() {
2504        // Per PostgreSQL, empty arrays are represented with zero dimensions,
2505        // not one dimension of zero length. We write this condition a little
2506        // strangely to satisfy the borrow checker while avoiding an allocation.
2507        dims = &[];
2508    }
2509    let datum = temp_storage.try_make_datum(|packer| packer.try_push_array(dims, datums))?;
2510    Ok(datum)
2511}
2512
/// Writes the text representation of the datum `d`, interpreted as a value of
/// type `ty`, into `buf`.
///
/// Scalar types are formatted by the matching `strconv::format_*` function.
/// Container types (records, arrays, lists, maps, ranges and `Int2Vector`)
/// recurse into this function for their elements, writing nulls via
/// `write_null` where the container formatter allows them.
///
/// Returns the `strconv::Nestable` value reported by the formatter that was
/// used.
///
/// # Errors
///
/// Propagates errors from `format_pg_legacy_char` and from nested calls.
///
/// # Panics
///
/// `d` is accessed with the `unwrap_*` method matching `ty`, so `d` is
/// expected to be non-null and of the kind that `ty` describes.
fn stringify_datum<'a, B>(
    buf: &mut B,
    d: Datum<'a>,
    ty: &SqlScalarType,
) -> Result<strconv::Nestable, EvalError>
where
    B: FormatBuffer,
{
    use SqlScalarType::*;
    match &ty {
        AclItem => Ok(strconv::format_acl_item(buf, d.unwrap_acl_item())),
        Bool => Ok(strconv::format_bool(buf, d.unwrap_bool())),
        Int16 => Ok(strconv::format_int16(buf, d.unwrap_int16())),
        Int32 => Ok(strconv::format_int32(buf, d.unwrap_int32())),
        Int64 => Ok(strconv::format_int64(buf, d.unwrap_int64())),
        UInt16 => Ok(strconv::format_uint16(buf, d.unwrap_uint16())),
        // The OID-like types are all read and printed as 32-bit unsigned
        // integers.
        UInt32 | Oid | RegClass | RegProc | RegType => {
            Ok(strconv::format_uint32(buf, d.unwrap_uint32()))
        }
        UInt64 => Ok(strconv::format_uint64(buf, d.unwrap_uint64())),
        Float32 => Ok(strconv::format_float32(buf, d.unwrap_float32())),
        Float64 => Ok(strconv::format_float64(buf, d.unwrap_float64())),
        Numeric { .. } => Ok(strconv::format_numeric(buf, &d.unwrap_numeric())),
        Date => Ok(strconv::format_date(buf, d.unwrap_date())),
        Time => Ok(strconv::format_time(buf, d.unwrap_time())),
        Timestamp { .. } => Ok(strconv::format_timestamp(buf, &d.unwrap_timestamp())),
        TimestampTz { .. } => Ok(strconv::format_timestamptz(buf, &d.unwrap_timestamptz())),
        Interval => Ok(strconv::format_interval(buf, d.unwrap_interval())),
        Bytes => Ok(strconv::format_bytes(buf, d.unwrap_bytes())),
        String | VarChar { .. } | PgLegacyName => Ok(strconv::format_string(buf, d.unwrap_str())),
        // `Char` values are passed through `format_str_pad` with the declared
        // length before being written.
        Char { length } => Ok(strconv::format_string(
            buf,
            &mz_repr::adt::char::format_str_pad(d.unwrap_str(), *length),
        )),
        // `format_pg_legacy_char` does not report a `Nestable` itself, so
        // `MayNeedEscaping` is always returned for this type.
        PgLegacyChar => {
            format_pg_legacy_char(buf, d.unwrap_uint8())?;
            Ok(strconv::Nestable::MayNeedEscaping)
        }
        Jsonb => Ok(strconv::format_jsonb(buf, JsonbRef::from_datum(d))),
        Uuid => Ok(strconv::format_uuid(buf, d.unwrap_uuid())),
        Record { fields, .. } => {
            // Record elements are matched to the declared fields by position;
            // the `unwrap` panics if the datum holds more elements than the
            // type declares fields.
            let mut fields = fields.iter();
            strconv::format_record(buf, d.unwrap_list(), |buf, d| {
                let (_name, ty) = fields.next().unwrap();
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, &ty.scalar_type)
                }
            })
        }
        Array(elem_type) => strconv::format_array(
            buf,
            &d.unwrap_array().dims().into_iter().collect::<Vec<_>>(),
            d.unwrap_array().elements(),
            |buf, d| {
                if d.is_null() {
                    Ok(buf.write_null())
                } else {
                    stringify_datum(buf.nonnull_buffer(), d, elem_type)
                }
            },
        ),
        List { element_type, .. } => strconv::format_list(buf, d.unwrap_list(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, element_type)
            }
        }),
        Map { value_type, .. } => strconv::format_map(buf, &d.unwrap_map(), |buf, d| {
            if d.is_null() {
                Ok(buf.write_null())
            } else {
                stringify_datum(buf.nonnull_buffer(), d, value_type)
            }
        }),
        // Elements are always formatted as `Int16`; unlike the other
        // containers there is no null check here.
        Int2Vector => strconv::format_legacy_vector(buf, d.unwrap_array().elements(), |buf, d| {
            stringify_datum(buf.nonnull_buffer(), d, &SqlScalarType::Int16)
        }),
        MzTimestamp { .. } => Ok(strconv::format_mz_timestamp(buf, d.unwrap_mz_timestamp())),
        // A range bound is passed as an `Option`: `None` is written as null.
        Range { element_type } => strconv::format_range(buf, &d.unwrap_range(), |buf, d| match d {
            Some(d) => stringify_datum(buf.nonnull_buffer(), *d, element_type),
            None => Ok::<_, EvalError>(buf.write_null()),
        }),
        MzAclItem => Ok(strconv::format_mz_acl_item(buf, d.unwrap_mz_acl_item())),
    }
}
2601
2602#[sqlfunc]
2603fn position(substring: &str, string: &str) -> Result<i32, EvalError> {
2604    let char_index = string.find(substring);
2605
2606    if let Some(char_index) = char_index {
2607        // find the index in char space
2608        let string_prefix = &string[0..char_index];
2609
2610        let num_prefix_chars = string_prefix.chars().count();
2611        let num_prefix_chars = i32::try_from(num_prefix_chars)
2612            .map_err(|_| EvalError::Int32OutOfRange(num_prefix_chars.to_string().into()))?;
2613
2614        Ok(num_prefix_chars + 1)
2615    } else {
2616        Ok(0)
2617    }
2618}
2619
/// Same as `position`, but with the arguments in the opposite order: the
/// string to search comes first and the substring to look for second.
#[sqlfunc]
fn strpos(string: &str, substring: &str) -> Result<i32, EvalError> {
    position(substring, string)
}
2624
2625#[sqlfunc(
2626    propagates_nulls = true,
2627    // `left` is unfortunately not monotonic (at least for negative second arguments),
2628    // because 'aa' < 'z', but `left(_, -1)` makes 'a' > ''.
2629    is_monotone = (false, false)
2630)]
2631fn left<'a>(string: &'a str, b: i32) -> Result<&'a str, EvalError> {
2632    let n = i64::from(b);
2633
2634    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2635
2636    let end_in_bytes = match n.cmp(&0) {
2637        Ordering::Equal => 0,
2638        Ordering::Greater => {
2639            let n = usize::try_from(n).map_err(|_| {
2640                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2641            })?;
2642            // nth from the back
2643            byte_indices.nth(n).unwrap_or(string.len())
2644        }
2645        Ordering::Less => {
2646            let n = usize::try_from(n.abs() - 1).map_err(|_| {
2647                EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2648            })?;
2649            byte_indices.rev().nth(n).unwrap_or(0)
2650        }
2651    };
2652
2653    Ok(&string[..end_in_bytes])
2654}
2655
2656#[sqlfunc(propagates_nulls = true)]
2657fn right<'a>(string: &'a str, n: i32) -> Result<&'a str, EvalError> {
2658    let mut byte_indices = string.char_indices().map(|(i, _)| i);
2659
2660    let start_in_bytes = if n == 0 {
2661        string.len()
2662    } else if n > 0 {
2663        let n = usize::try_from(n - 1).map_err(|_| {
2664            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2665        })?;
2666        // nth from the back
2667        byte_indices.rev().nth(n).unwrap_or(0)
2668    } else if n == i32::MIN {
2669        // this seems strange but Postgres behaves like this
2670        0
2671    } else {
2672        let n = n.abs();
2673        let n = usize::try_from(n).map_err(|_| {
2674            EvalError::InvalidParameterValue(format!("invalid parameter n: {:?}", n).into())
2675        })?;
2676        byte_indices.nth(n).unwrap_or(string.len())
2677    };
2678
2679    Ok(&string[start_in_bytes..])
2680}
2681
2682#[sqlfunc(sqlname = "btrim", propagates_nulls = true)]
/// Strips from both ends of `a` every character that occurs in `trim_chars`.
///
/// `trim_chars` is treated as a set of characters, not as a substring.
fn trim<'a>(a: &'a str, trim_chars: &str) -> &'a str {
    let in_trim_set = |candidate: char| trim_chars.contains(candidate);
    a.trim_matches(in_trim_set)
}
2686
2687#[sqlfunc(sqlname = "ltrim", propagates_nulls = true)]
/// Strips from the start of `a` every character that occurs in `trim_chars`.
///
/// `trim_chars` is treated as a set of characters, not as a substring.
fn trim_leading<'a>(a: &'a str, trim_chars: &str) -> &'a str {
    let in_trim_set = |candidate: char| trim_chars.contains(candidate);
    a.trim_start_matches(in_trim_set)
}
2691
2692#[sqlfunc(sqlname = "rtrim", propagates_nulls = true)]
/// Strips from the end of `a` every character that occurs in `trim_chars`.
///
/// `trim_chars` is treated as a set of characters, not as a substring.
fn trim_trailing<'a>(a: &'a str, trim_chars: &str) -> &'a str {
    let in_trim_set = |candidate: char| trim_chars.contains(candidate);
    a.trim_end_matches(in_trim_set)
}
2696
2697#[sqlfunc(
2698    sqlname = "array_length",
2699    propagates_nulls = true,
2700    introduces_nulls = true
2701)]
2702fn array_length<'a>(a: Array<'a>, b: i64) -> Result<Option<i32>, EvalError> {
2703    let i = match usize::try_from(b) {
2704        Ok(0) | Err(_) => return Ok(None),
2705        Ok(n) => n - 1,
2706    };
2707    Ok(match a.dims().into_iter().nth(i) {
2708        None => None,
2709        Some(dim) => Some(
2710            dim.length
2711                .try_into()
2712                .map_err(|_| EvalError::Int32OutOfRange(dim.length.to_string().into()))?,
2713        ),
2714    })
2715}
2716
2717#[sqlfunc(
2718    output_type = "Option<i32>",
2719    is_infix_op = true,
2720    sqlname = "array_lower",
2721    propagates_nulls = true,
2722    introduces_nulls = true
2723)]
2724// TODO(benesch): remove potentially dangerous usage of `as`.
2725#[allow(clippy::as_conversions)]
2726fn array_lower<'a>(a: Array<'a>, i: i64) -> Option<i32> {
2727    if i < 1 {
2728        return None;
2729    }
2730    match a.dims().into_iter().nth(i as usize - 1) {
2731        Some(_) => Some(1),
2732        None => None,
2733    }
2734}
2735
2736#[sqlfunc(
2737    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2738    sqlname = "array_remove",
2739    propagates_nulls = false,
2740    introduces_nulls = false
2741)]
2742fn array_remove<'a>(
2743    arr: Array<'a>,
2744    b: Datum<'a>,
2745    temp_storage: &'a RowArena,
2746) -> Result<Datum<'a>, EvalError> {
2747    // Zero-dimensional arrays are empty by definition
2748    if arr.dims().len() == 0 {
2749        return Ok(Datum::Array(arr));
2750    }
2751
2752    // array_remove only supports one-dimensional arrays
2753    if arr.dims().len() > 1 {
2754        return Err(EvalError::MultidimensionalArrayRemovalNotSupported);
2755    }
2756
2757    let elems: Vec<_> = arr.elements().iter().filter(|v| v != &b).collect();
2758    let mut dims = arr.dims().into_iter().collect::<Vec<_>>();
2759    // This access is safe because `dims` is guaranteed to be non-empty
2760    dims[0] = ArrayDimension {
2761        lower_bound: 1,
2762        length: elems.len(),
2763    };
2764
2765    Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?)
2766}
2767
2768#[sqlfunc(
2769    output_type = "Option<i32>",
2770    is_infix_op = true,
2771    sqlname = "array_upper",
2772    propagates_nulls = true,
2773    introduces_nulls = true
2774)]
2775// TODO(benesch): remove potentially dangerous usage of `as`.
2776#[allow(clippy::as_conversions)]
2777fn array_upper<'a>(a: Array<'a>, i: i64) -> Result<Option<i32>, EvalError> {
2778    if i < 1 {
2779        return Ok(None);
2780    }
2781    a.dims()
2782        .into_iter()
2783        .nth(i as usize - 1)
2784        .map(|dim| {
2785            dim.length
2786                .try_into()
2787                .map_err(|_| EvalError::Int32OutOfRange(dim.length.to_string().into()))
2788        })
2789        .transpose()
2790}
2791
2792#[sqlfunc(
2793    is_infix_op = true,
2794    sqlname = "array_contains",
2795    propagates_nulls = true,
2796    introduces_nulls = false
2797)]
2798fn array_contains<'a>(a: Datum<'a>, array: Array<'a>) -> bool {
2799    array.elements().iter().any(|e| e == a)
2800}
2801
2802#[sqlfunc(is_infix_op = true, sqlname = "@>")]
2803fn array_contains_array<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2804    let a = a.elements();
2805    let b = b.elements();
2806
2807    // NULL is never equal to NULL. If NULL is an element of b, b cannot be contained in a, even if a contains NULL.
2808    if b.iter().contains(&Datum::Null) {
2809        false
2810    } else {
2811        b.iter()
2812            .all(|item_b| a.iter().any(|item_a| item_a == item_b))
2813    }
2814}
2815
2816#[sqlfunc(is_infix_op = true, sqlname = "<@")]
2817fn array_contains_array_rev<'a>(a: Array<'a>, b: Array<'a>) -> bool {
2818    array_contains_array(b, a)
2819}
2820
2821#[sqlfunc(
2822    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2823    is_infix_op = true,
2824    sqlname = "||",
2825    propagates_nulls = false,
2826    introduces_nulls = false
2827)]
2828fn array_array_concat<'a>(
2829    a: Datum<'a>,
2830    b: Datum<'a>,
2831    temp_storage: &'a RowArena,
2832) -> Result<Datum<'a>, EvalError> {
2833    if a.is_null() {
2834        return Ok(b);
2835    } else if b.is_null() {
2836        return Ok(a);
2837    }
2838
2839    let a_array = a.unwrap_array();
2840    let b_array = b.unwrap_array();
2841
2842    let a_dims: Vec<ArrayDimension> = a_array.dims().into_iter().collect();
2843    let b_dims: Vec<ArrayDimension> = b_array.dims().into_iter().collect();
2844
2845    let a_ndims = a_dims.len();
2846    let b_ndims = b_dims.len();
2847
2848    // Per PostgreSQL, if either of the input arrays is zero dimensional,
2849    // the output is the other array, no matter their dimensions.
2850    if a_ndims == 0 {
2851        return Ok(b);
2852    } else if b_ndims == 0 {
2853        return Ok(a);
2854    }
2855
2856    // Postgres supports concatenating arrays of different dimensions,
2857    // as long as one of the arrays has the same type as an element of
2858    // the other array, i.e. `int[2][4] || int[4]` (or `int[4] || int[2][4]`)
2859    // works, because each element of `int[2][4]` is an `int[4]`.
2860    // This check is separate from the one below because Postgres gives a
2861    // specific error message if the number of dimensions differs by more
2862    // than one.
2863    // This cast is safe since MAX_ARRAY_DIMENSIONS is 6
2864    // Can be replaced by .abs_diff once it is stabilized
2865    // TODO(benesch): remove potentially dangerous usage of `as`.
2866    #[allow(clippy::as_conversions)]
2867    if (a_ndims as isize - b_ndims as isize).abs() > 1 {
2868        return Err(EvalError::IncompatibleArrayDimensions {
2869            dims: Some((a_ndims, b_ndims)),
2870        });
2871    }
2872
2873    let mut dims;
2874
2875    // After the checks above, we are certain that:
2876    // - neither array is zero dimensional nor empty
2877    // - both arrays have the same number of dimensions, or differ
2878    //   at most by one.
2879    match a_ndims.cmp(&b_ndims) {
2880        // If both arrays have the same number of dimensions, validate
2881        // that their inner dimensions are the same and concatenate the
2882        // arrays.
2883        Ordering::Equal => {
2884            if &a_dims[1..] != &b_dims[1..] {
2885                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2886            }
2887            dims = vec![ArrayDimension {
2888                lower_bound: a_dims[0].lower_bound,
2889                length: a_dims[0].length + b_dims[0].length,
2890            }];
2891            dims.extend(&a_dims[1..]);
2892        }
2893        // If `a` has less dimensions than `b`, this is an element-array
2894        // concatenation, which requires that `a` has the same dimensions
2895        // as an element of `b`.
2896        Ordering::Less => {
2897            if &a_dims[..] != &b_dims[1..] {
2898                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2899            }
2900            dims = vec![ArrayDimension {
2901                lower_bound: b_dims[0].lower_bound,
2902                // Since `a` is treated as an element of `b`, the length of
2903                // the first dimension of `b` is incremented by one, as `a` is
2904                // non-empty.
2905                length: b_dims[0].length + 1,
2906            }];
2907            dims.extend(a_dims);
2908        }
2909        // If `a` has more dimensions than `b`, this is an array-element
2910        // concatenation, which requires that `b` has the same dimensions
2911        // as an element of `a`.
2912        Ordering::Greater => {
2913            if &a_dims[1..] != &b_dims[..] {
2914                return Err(EvalError::IncompatibleArrayDimensions { dims: None });
2915            }
2916            dims = vec![ArrayDimension {
2917                lower_bound: a_dims[0].lower_bound,
2918                // Since `b` is treated as an element of `a`, the length of
2919                // the first dimension of `a` is incremented by one, as `b`
2920                // is non-empty.
2921                length: a_dims[0].length + 1,
2922            }];
2923            dims.extend(b_dims);
2924        }
2925    }
2926
2927    let elems = a_array.elements().iter().chain(b_array.elements().iter());
2928
2929    Ok(temp_storage.try_make_datum(|packer| packer.try_push_array(&dims, elems))?)
2930}
2931
2932#[sqlfunc(
2933    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2934    is_infix_op = true,
2935    sqlname = "||",
2936    propagates_nulls = false,
2937    introduces_nulls = false
2938)]
2939fn list_list_concat<'a>(a: Datum<'a>, b: Datum<'a>, temp_storage: &'a RowArena) -> Datum<'a> {
2940    if a.is_null() {
2941        return b;
2942    } else if b.is_null() {
2943        return a;
2944    }
2945
2946    let a = a.unwrap_list().iter();
2947    let b = b.unwrap_list().iter();
2948
2949    temp_storage.make_datum(|packer| packer.push_list(a.chain(b)))
2950}
2951
2952#[sqlfunc(
2953    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2954    is_infix_op = true,
2955    sqlname = "||",
2956    propagates_nulls = false,
2957    introduces_nulls = false
2958)]
2959fn list_element_concat<'a>(a: Datum<'a>, b: Datum<'a>, temp_storage: &'a RowArena) -> Datum<'a> {
2960    temp_storage.make_datum(|packer| {
2961        packer.push_list_with(|packer| {
2962            if !a.is_null() {
2963                for elem in a.unwrap_list().iter() {
2964                    packer.push(elem);
2965                }
2966            }
2967            packer.push(b);
2968        })
2969    })
2970}
2971
2972// Note that the output type corresponds to the _second_ parameter's input type.
2973#[sqlfunc(
2974    output_type_expr = "input_types[1].scalar_type.without_modifiers().nullable(true)",
2975    is_infix_op = true,
2976    sqlname = "||",
2977    propagates_nulls = false,
2978    introduces_nulls = false
2979)]
2980fn element_list_concat<'a>(a: Datum<'a>, b: Datum<'a>, temp_storage: &'a RowArena) -> Datum<'a> {
2981    temp_storage.make_datum(|packer| {
2982        packer.push_list_with(|packer| {
2983            packer.push(a);
2984            if !b.is_null() {
2985                for elem in b.unwrap_list().iter() {
2986                    packer.push(elem);
2987                }
2988            }
2989        })
2990    })
2991}
2992
2993#[sqlfunc(
2994    output_type_expr = "input_types[0].scalar_type.without_modifiers().nullable(true)",
2995    sqlname = "list_remove",
2996    propagates_nulls = false,
2997    introduces_nulls = false
2998)]
2999fn list_remove<'a>(a: DatumList<'a>, b: Datum<'a>, temp_storage: &'a RowArena) -> Datum<'a> {
3000    temp_storage.make_datum(|packer| {
3001        packer.push_list_with(|packer| {
3002            for elem in a.iter() {
3003                if elem != b {
3004                    packer.push(elem);
3005                }
3006            }
3007        })
3008    })
3009}
3010
3011#[sqlfunc(
3012    output_type = "Vec<u8>",
3013    sqlname = "digest",
3014    propagates_nulls = true,
3015    introduces_nulls = false
3016)]
3017fn digest_string<'a>(a: &str, b: &str, temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
3018    let to_digest = a.as_bytes();
3019    digest_inner(to_digest, b, temp_storage)
3020}
3021
3022#[sqlfunc(
3023    output_type = "Vec<u8>",
3024    sqlname = "digest",
3025    propagates_nulls = true,
3026    introduces_nulls = false
3027)]
3028fn digest_bytes<'a>(a: &[u8], b: &str, temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
3029    let to_digest = a;
3030    digest_inner(to_digest, b, temp_storage)
3031}
3032
3033fn digest_inner<'a>(
3034    bytes: &[u8],
3035    digest_fn: &str,
3036    temp_storage: &'a RowArena,
3037) -> Result<Datum<'a>, EvalError> {
3038    let bytes = match digest_fn {
3039        "md5" => Md5::digest(bytes).to_vec(),
3040        "sha1" => Sha1::digest(bytes).to_vec(),
3041        "sha224" => Sha224::digest(bytes).to_vec(),
3042        "sha256" => Sha256::digest(bytes).to_vec(),
3043        "sha384" => Sha384::digest(bytes).to_vec(),
3044        "sha512" => Sha512::digest(bytes).to_vec(),
3045        other => return Err(EvalError::InvalidHashAlgorithm(other.into())),
3046    };
3047    Ok(Datum::Bytes(temp_storage.push_bytes(bytes)))
3048}
3049
3050#[sqlfunc(
3051    output_type = "String",
3052    sqlname = "mz_render_typmod",
3053    propagates_nulls = true,
3054    introduces_nulls = false
3055)]
3056fn mz_render_typmod<'a>(
3057    oid: u32,
3058    typmod: i32,
3059    temp_storage: &'a RowArena,
3060) -> Result<Datum<'a>, EvalError> {
3061    let s = match Type::from_oid_and_typmod(oid, typmod) {
3062        Ok(typ) => typ.constraint().display_or("").to_string(),
3063        // Match dubious PostgreSQL behavior of outputting the unmodified
3064        // `typmod` when positive if the type OID/typmod is invalid.
3065        Err(_) if typmod >= 0 => format!("({typmod})"),
3066        Err(_) => "".into(),
3067    };
3068    Ok(Datum::String(temp_storage.push_string(s)))
3069}
3070
3071#[cfg(test)]
3072mod test {
3073    use chrono::prelude::*;
3074    use mz_repr::PropDatum;
3075    use proptest::prelude::*;
3076
3077    use super::*;
3078    use crate::MirScalarExpr;
3079
3080    #[mz_ore::test]
3081    fn add_interval_months() {
3082        let dt = ym(2000, 1);
3083
3084        assert_eq!(add_timestamp_months(&*dt, 0).unwrap(), dt);
3085        assert_eq!(add_timestamp_months(&*dt, 1).unwrap(), ym(2000, 2));
3086        assert_eq!(add_timestamp_months(&*dt, 12).unwrap(), ym(2001, 1));
3087        assert_eq!(add_timestamp_months(&*dt, 13).unwrap(), ym(2001, 2));
3088        assert_eq!(add_timestamp_months(&*dt, 24).unwrap(), ym(2002, 1));
3089        assert_eq!(add_timestamp_months(&*dt, 30).unwrap(), ym(2002, 7));
3090
3091        // and negatives
3092        assert_eq!(add_timestamp_months(&*dt, -1).unwrap(), ym(1999, 12));
3093        assert_eq!(add_timestamp_months(&*dt, -12).unwrap(), ym(1999, 1));
3094        assert_eq!(add_timestamp_months(&*dt, -13).unwrap(), ym(1998, 12));
3095        assert_eq!(add_timestamp_months(&*dt, -24).unwrap(), ym(1998, 1));
3096        assert_eq!(add_timestamp_months(&*dt, -30).unwrap(), ym(1997, 7));
3097
3098        // and going over a year boundary by less than a year
3099        let dt = ym(1999, 12);
3100        assert_eq!(add_timestamp_months(&*dt, 1).unwrap(), ym(2000, 1));
3101        let end_of_month_dt = NaiveDate::from_ymd_opt(1999, 12, 31)
3102            .unwrap()
3103            .and_hms_opt(9, 9, 9)
3104            .unwrap();
3105        assert_eq!(
3106            // leap year
3107            add_timestamp_months(&end_of_month_dt, 2).unwrap(),
3108            NaiveDate::from_ymd_opt(2000, 2, 29)
3109                .unwrap()
3110                .and_hms_opt(9, 9, 9)
3111                .unwrap()
3112                .try_into()
3113                .unwrap(),
3114        );
3115        assert_eq!(
3116            // not leap year
3117            add_timestamp_months(&end_of_month_dt, 14).unwrap(),
3118            NaiveDate::from_ymd_opt(2001, 2, 28)
3119                .unwrap()
3120                .and_hms_opt(9, 9, 9)
3121                .unwrap()
3122                .try_into()
3123                .unwrap(),
3124        );
3125    }
3126
3127    fn ym(year: i32, month: u32) -> CheckedTimestamp<NaiveDateTime> {
3128        NaiveDate::from_ymd_opt(year, month, 1)
3129            .unwrap()
3130            .and_hms_opt(9, 9, 9)
3131            .unwrap()
3132            .try_into()
3133            .unwrap()
3134    }
3135
    /// Property test: for a handful of binary functions, every argument
    /// position that `is_monotone()` flags as monotone must actually produce
    /// monotone results when that argument varies over sorted values.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that the function is either monotonically increasing or decreasing over
        /// the given sets of arguments.
        ///
        /// If evaluating any of the argument sets fails, the whole check is
        /// skipped rather than treated as a failure.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let Ok(results) = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            else {
                return;
            };

            // Adjacent results must be all non-decreasing or all non-increasing.
            let forward = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let reverse = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                forward || reverse,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Builds a binary call of `func` over columns 0 and 1 and, for each
        /// argument position that `func.is_monotone()` reports as monotone,
        /// holds the other argument fixed and checks the results over three
        /// sorted generated values.
        fn proptest_binary<'a>(
            func: BinaryFunc,
            arena: &'a RowArena,
            left: impl Strategy<Value = PropDatum>,
            right: impl Strategy<Value = PropDatum>,
        ) {
            let (left_monotone, right_monotone) = func.is_monotone();
            let expr = MirScalarExpr::CallBinary {
                func,
                expr1: Box::new(MirScalarExpr::column(0)),
                expr2: Box::new(MirScalarExpr::column(1)),
            };
            proptest!(|(
                mut left in proptest::array::uniform3(left),
                mut right in proptest::array::uniform3(right),
            )| {
                // Sorted inputs are what make the monotonicity check meaningful.
                left.sort();
                right.sort();
                if left_monotone {
                    // Vary the left argument while each right value stays fixed.
                    for r in &right {
                        let args: Vec<[_; 2]> = left
                            .iter()
                            .map(|l| [Datum::from(l), Datum::from(r)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
                if right_monotone {
                    // Vary the right argument while each left value stays fixed.
                    for l in &left {
                        let args: Vec<[_; 2]> = right
                            .iter()
                            .map(|r| [Datum::from(l), Datum::from(r)])
                            .collect();
                        assert_monotone(&expr, arena, &args);
                    }
                }
            });
        }

        // String inputs: generated uppercase strings of up to ten characters,
        // plus the string type's predefined "interesting" datums.
        let interesting_strs: Vec<_> = SqlScalarType::String.interesting_datums().collect();
        let str_datums = proptest::strategy::Union::new([
            proptest::string::string_regex("[A-Z]{0,10}")
                .expect("valid regex")
                .prop_map(|s| PropDatum::String(s.to_string()))
                .boxed(),
            (0..interesting_strs.len())
                .prop_map(move |i| {
                    let Datum::String(val) = interesting_strs[i] else {
                        unreachable!("interesting strings has non-strings")
                    };
                    PropDatum::String(val.to_string())
                })
                .boxed(),
        ]);

        // Int32 inputs: arbitrary values, the type's predefined "interesting"
        // datums, and a small range around zero.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| {
                    let Datum::Int32(val) = interesting_i32s[i] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_binary(
            BinaryFunc::AddInt32(AddInt32),
            &arena,
            &i32_datums,
            &i32_datums,
        );
        proptest_binary(SubInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(MulInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(DivInt32.into(), &arena, &i32_datums, &i32_datums);
        proptest_binary(TextConcatBinary.into(), &arena, &str_datums, &str_datums);
        proptest_binary(Left.into(), &arena, &str_datums, &i32_datums);
    }
3252}