// mz_expr/scalar/func/unary.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//
// Portions of this file are derived from the PostgreSQL project. The original
// source code is subject to the terms of the PostgreSQL license, a copy of
// which can be found in the LICENSE file at the root of this repository.

//! Definition of the [`UnaryFunc`] enum and related machinery.

16use std::{fmt, str};
17
18use mz_repr::{Datum, DatumType, RowArena, SqlColumnType};
19
20use crate::scalar::func::impls::*;
21use crate::{EvalError, MirScalarExpr};
22
23/// A description of an SQL unary function that has the ability to lazy evaluate its arguments
24// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
25pub trait LazyUnaryFunc {
26    fn eval<'a>(
27        &'a self,
28        datums: &[Datum<'a>],
29        temp_storage: &'a RowArena,
30        a: &'a MirScalarExpr,
31    ) -> Result<Datum<'a>, EvalError>;
32
33    /// The output SqlColumnType of this function.
34    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType;
35
36    /// Whether this function will produce NULL on NULL input.
37    fn propagates_nulls(&self) -> bool;
38
39    /// Whether this function will produce NULL on non-NULL input.
40    fn introduces_nulls(&self) -> bool;
41
42    /// Whether this function might error on non-error input.
43    fn could_error(&self) -> bool {
44        // NB: override this for functions that never error.
45        true
46    }
47
48    /// Whether this function preserves uniqueness.
49    ///
50    /// Uniqueness is preserved when `if f(x) = f(y) then x = y` is true. This
51    /// is used by the optimizer when a guarantee can be made that a collection
52    /// with unique items will stay unique when mapped by this function.
53    ///
54    /// Note that error results are not covered: Even with `preserves_uniqueness = true`, it can
55    /// happen that two different inputs produce the same error result. (e.g., in case of a
56    /// narrowing cast)
57    ///
58    /// Functions should conservatively return `false` unless they are certain
59    /// the above property is true.
60    fn preserves_uniqueness(&self) -> bool;
61
62    /// The [inverse] of this function, if it has one and we have determined it.
63    ///
64    /// The optimizer _can_ use this information when selecting indexes, e.g. an
65    /// indexed column has a cast applied to it, by moving the right inverse of
66    /// the cast to another value, we can select the indexed column.
67    ///
68    /// Note that a value of `None` does not imply that the inverse does not
69    /// exist; it could also mean we have not yet invested the energy in
70    /// representing it. For example, in the case of complex casts, such as
71    /// between two list types, we could determine the right inverse, but doing
72    /// so is not immediately necessary as this information is only used by the
73    /// optimizer.
74    ///
75    /// ## Right vs. left vs. inverses
76    /// - Right inverses are when the inverse function preserves uniqueness.
77    ///   These are the functions that the optimizer uses to move casts between
78    ///   expressions.
79    /// - Left inverses are when the function itself preserves uniqueness.
80    /// - Inverses are when a function is both a right and a left inverse (e.g.,
81    ///   bit_not_int64 is both a right and left inverse of itself).
82    ///
83    /// We call this function `inverse` for simplicity's sake; it doesn't always
84    /// correspond to the mathematical notion of "inverse." However, in
85    /// conjunction with checks to `preserves_uniqueness` you can determine
86    /// which type of inverse we return.
87    ///
88    /// [inverse]: https://en.wikipedia.org/wiki/Inverse_function
89    fn inverse(&self) -> Option<crate::UnaryFunc>;
90
91    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
92    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
93    /// determine the range of possible outputs just by mapping the endpoints.
94    ///
95    /// This property describes the behaviour of the function over ranges where the function is defined:
96    /// ie. the argument and the result are non-error datums.
97    fn is_monotone(&self) -> bool;
98}
99
100/// A description of an SQL unary function that operates on eagerly evaluated expressions
101pub trait EagerUnaryFunc<'a> {
102    type Input: DatumType<'a, EvalError>;
103    type Output: DatumType<'a, EvalError>;
104
105    fn call(&self, input: Self::Input) -> Self::Output;
106
107    /// The output SqlColumnType of this function
108    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType;
109
110    /// Whether this function will produce NULL on NULL input
111    fn propagates_nulls(&self) -> bool {
112        // If the input is not nullable then nulls are propagated
113        !Self::Input::nullable()
114    }
115
116    /// Whether this function will produce NULL on non-NULL input
117    fn introduces_nulls(&self) -> bool {
118        // If the output is nullable then nulls can be introduced
119        Self::Output::nullable()
120    }
121
122    /// Whether this function could produce an error
123    fn could_error(&self) -> bool {
124        Self::Output::fallible()
125    }
126
127    /// Whether this function preserves uniqueness
128    fn preserves_uniqueness(&self) -> bool {
129        false
130    }
131
132    fn inverse(&self) -> Option<crate::UnaryFunc> {
133        None
134    }
135
136    fn is_monotone(&self) -> bool {
137        false
138    }
139}
140
141impl<T: for<'a> EagerUnaryFunc<'a>> LazyUnaryFunc for T {
142    fn eval<'a>(
143        &'a self,
144        datums: &[Datum<'a>],
145        temp_storage: &'a RowArena,
146        a: &'a MirScalarExpr,
147    ) -> Result<Datum<'a>, EvalError> {
148        match T::Input::try_from_result(a.eval(datums, temp_storage)) {
149            // If we can convert to the input type then we call the function
150            Ok(input) => self.call(input).into_result(temp_storage),
151            // If we can't and we got a non-null datum something went wrong in the planner
152            Err(Ok(datum)) if !datum.is_null() => {
153                Err(EvalError::Internal("invalid input type".into()))
154            }
155            // Otherwise we just propagate NULLs and errors
156            Err(res) => res,
157        }
158    }
159
160    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
161        self.output_type(input_type)
162    }
163
164    fn propagates_nulls(&self) -> bool {
165        self.propagates_nulls()
166    }
167
168    fn introduces_nulls(&self) -> bool {
169        self.introduces_nulls()
170    }
171
172    fn could_error(&self) -> bool {
173        self.could_error()
174    }
175
176    fn preserves_uniqueness(&self) -> bool {
177        self.preserves_uniqueness()
178    }
179
180    fn inverse(&self) -> Option<crate::UnaryFunc> {
181        self.inverse()
182    }
183
184    fn is_monotone(&self) -> bool {
185        self.is_monotone()
186    }
187}
188
// Lists every unary function type known to this module. Judging from the uses below
// (`UnaryFunc::IsNull(_)` in `impl UnaryFunc`, `UnaryFunc::IsNull(IsNull)` and
// `UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None))` in the tests), this presumably
// expands to the `UnaryFunc` enum with one variant per name, each wrapping the function type
// of the same name.
// NOTE(review): `derive_unary!` is defined outside this file. Confirm whether the listing
// order is significant (e.g. for derived ordering or serialization) before reordering entries
// or inserting anywhere other than the end.
derive_unary!(
    Not,
    IsNull,
    IsTrue,
    IsFalse,
    BitNotInt16,
    BitNotInt32,
    BitNotInt64,
    BitNotUint16,
    BitNotUint32,
    BitNotUint64,
    NegInt16,
    NegInt32,
    NegInt64,
    NegFloat32,
    NegFloat64,
    NegNumeric,
    NegInterval,
    SqrtFloat64,
    SqrtNumeric,
    CbrtFloat64,
    AbsInt16,
    AbsInt32,
    AbsInt64,
    AbsFloat32,
    AbsFloat64,
    AbsNumeric,
    CastBoolToString,
    CastBoolToStringNonstandard,
    CastBoolToInt32,
    CastBoolToInt64,
    CastInt16ToFloat32,
    CastInt16ToFloat64,
    CastInt16ToInt32,
    CastInt16ToInt64,
    CastInt16ToUint16,
    CastInt16ToUint32,
    CastInt16ToUint64,
    CastInt16ToString,
    CastInt2VectorToArray,
    CastInt32ToBool,
    CastInt32ToFloat32,
    CastInt32ToFloat64,
    CastInt32ToOid,
    CastInt32ToPgLegacyChar,
    CastInt32ToInt16,
    CastInt32ToInt64,
    CastInt32ToUint16,
    CastInt32ToUint32,
    CastInt32ToUint64,
    CastInt32ToString,
    CastOidToInt32,
    CastOidToInt64,
    CastOidToString,
    CastOidToRegClass,
    CastRegClassToOid,
    CastOidToRegProc,
    CastRegProcToOid,
    CastOidToRegType,
    CastRegTypeToOid,
    CastInt64ToInt16,
    CastInt64ToInt32,
    CastInt64ToUint16,
    CastInt64ToUint32,
    CastInt64ToUint64,
    CastInt16ToNumeric,
    CastInt32ToNumeric,
    CastInt64ToBool,
    CastInt64ToNumeric,
    CastInt64ToFloat32,
    CastInt64ToFloat64,
    CastInt64ToOid,
    CastInt64ToString,
    CastUint16ToUint32,
    CastUint16ToUint64,
    CastUint16ToInt16,
    CastUint16ToInt32,
    CastUint16ToInt64,
    CastUint16ToNumeric,
    CastUint16ToFloat32,
    CastUint16ToFloat64,
    CastUint16ToString,
    CastUint32ToUint16,
    CastUint32ToUint64,
    CastUint32ToInt16,
    CastUint32ToInt32,
    CastUint32ToInt64,
    CastUint32ToNumeric,
    CastUint32ToFloat32,
    CastUint32ToFloat64,
    CastUint32ToString,
    CastUint64ToUint16,
    CastUint64ToUint32,
    CastUint64ToInt16,
    CastUint64ToInt32,
    CastUint64ToInt64,
    CastUint64ToNumeric,
    CastUint64ToFloat32,
    CastUint64ToFloat64,
    CastUint64ToString,
    CastFloat32ToInt16,
    CastFloat32ToInt32,
    CastFloat32ToInt64,
    CastFloat32ToUint16,
    CastFloat32ToUint32,
    CastFloat32ToUint64,
    CastFloat32ToFloat64,
    CastFloat32ToString,
    CastFloat32ToNumeric,
    CastFloat64ToNumeric,
    CastFloat64ToInt16,
    CastFloat64ToInt32,
    CastFloat64ToInt64,
    CastFloat64ToUint16,
    CastFloat64ToUint32,
    CastFloat64ToUint64,
    CastFloat64ToFloat32,
    CastFloat64ToString,
    CastNumericToFloat32,
    CastNumericToFloat64,
    CastNumericToInt16,
    CastNumericToInt32,
    CastNumericToInt64,
    CastNumericToUint16,
    CastNumericToUint32,
    CastNumericToUint64,
    CastNumericToString,
    CastMzTimestampToString,
    CastMzTimestampToTimestamp,
    CastMzTimestampToTimestampTz,
    CastStringToMzTimestamp,
    CastUint64ToMzTimestamp,
    CastUint32ToMzTimestamp,
    CastInt64ToMzTimestamp,
    CastInt32ToMzTimestamp,
    CastNumericToMzTimestamp,
    CastTimestampToMzTimestamp,
    CastTimestampTzToMzTimestamp,
    CastDateToMzTimestamp,
    CastStringToBool,
    CastStringToPgLegacyChar,
    CastStringToPgLegacyName,
    CastStringToBytes,
    CastStringToInt16,
    CastStringToInt32,
    CastStringToInt64,
    CastStringToUint16,
    CastStringToUint32,
    CastStringToUint64,
    CastStringToInt2Vector,
    CastStringToOid,
    CastStringToFloat32,
    CastStringToFloat64,
    CastStringToDate,
    CastStringToArray,
    CastStringToList,
    CastStringToMap,
    CastStringToRange,
    CastStringToTime,
    CastStringToTimestamp,
    CastStringToTimestampTz,
    CastStringToInterval,
    CastStringToNumeric,
    CastStringToUuid,
    CastStringToChar,
    PadChar,
    CastStringToVarChar,
    CastCharToString,
    CastVarCharToString,
    CastDateToTimestamp,
    CastDateToTimestampTz,
    CastDateToString,
    CastTimeToInterval,
    CastTimeToString,
    CastIntervalToString,
    CastIntervalToTime,
    CastTimestampToDate,
    AdjustTimestampPrecision,
    CastTimestampToTimestampTz,
    CastTimestampToString,
    CastTimestampToTime,
    CastTimestampTzToDate,
    CastTimestampTzToTimestamp,
    AdjustTimestampTzPrecision,
    CastTimestampTzToString,
    CastTimestampTzToTime,
    CastPgLegacyCharToString,
    CastPgLegacyCharToChar,
    CastPgLegacyCharToVarChar,
    CastPgLegacyCharToInt32,
    CastBytesToString,
    CastStringToJsonb,
    CastJsonbToString,
    CastJsonbableToJsonb,
    CastJsonbToInt16,
    CastJsonbToInt32,
    CastJsonbToInt64,
    CastJsonbToFloat32,
    CastJsonbToFloat64,
    CastJsonbToNumeric,
    CastJsonbToBool,
    CastUuidToString,
    CastRecordToString,
    CastRecord1ToRecord2,
    CastArrayToArray,
    CastArrayToJsonb,
    CastArrayToString,
    CastListToString,
    CastListToJsonb,
    CastList1ToList2,
    CastArrayToListOneDim,
    CastMapToString,
    CastInt2VectorToString,
    CastRangeToString,
    CeilFloat32,
    CeilFloat64,
    CeilNumeric,
    FloorFloat32,
    FloorFloat64,
    FloorNumeric,
    Ascii,
    BitCountBytes,
    BitLengthBytes,
    BitLengthString,
    ByteLengthBytes,
    ByteLengthString,
    CharLength,
    Chr,
    IsLikeMatch,
    IsRegexpMatch,
    RegexpMatch,
    ExtractInterval,
    ExtractTime,
    ExtractTimestamp,
    ExtractTimestampTz,
    ExtractDate,
    DatePartInterval,
    DatePartTime,
    DatePartTimestamp,
    DatePartTimestampTz,
    DateTruncTimestamp,
    DateTruncTimestampTz,
    TimezoneTimestamp,
    TimezoneTimestampTz,
    TimezoneTime,
    ToTimestamp,
    ToCharTimestamp,
    ToCharTimestampTz,
    JustifyDays,
    JustifyHours,
    JustifyInterval,
    JsonbArrayLength,
    JsonbTypeof,
    JsonbStripNulls,
    JsonbPretty,
    RoundFloat32,
    RoundFloat64,
    RoundNumeric,
    TruncFloat32,
    TruncFloat64,
    TruncNumeric,
    TrimWhitespace,
    TrimLeadingWhitespace,
    TrimTrailingWhitespace,
    Initcap,
    RecordGet,
    ListLength,
    MapLength,
    MapBuildFromRecordList,
    Upper,
    Lower,
    Cos,
    Acos,
    Cosh,
    Acosh,
    Sin,
    Asin,
    Sinh,
    Asinh,
    Tan,
    Atan,
    Tanh,
    Atanh,
    Cot,
    Degrees,
    Radians,
    Log10,
    Log10Numeric,
    Ln,
    LnNumeric,
    Exp,
    ExpNumeric,
    Sleep,
    Panic,
    AdjustNumericScale,
    PgColumnSize,
    MzRowSize,
    MzTypeName,
    StepMzTimestamp,
    RangeLower,
    RangeUpper,
    RangeEmpty,
    RangeLowerInc,
    RangeUpperInc,
    RangeLowerInf,
    RangeUpperInf,
    MzAclItemGrantor,
    MzAclItemGrantee,
    MzAclItemPrivileges,
    MzFormatPrivileges,
    MzValidatePrivileges,
    MzValidateRolePrivilege,
    AclItemGrantor,
    AclItemGrantee,
    AclItemPrivileges,
    QuoteIdent,
    TryParseMonotonicIso8601Timestamp,
    RegexpSplitToArray,
    PgSizePretty,
    Crc32Bytes,
    Crc32String,
    KafkaMurmur2Bytes,
    KafkaMurmur2String,
    SeahashBytes,
    SeahashString,
    Reverse
);
516
517impl UnaryFunc {
518    /// If the unary_func represents "IS X", return X.
519    ///
520    /// A helper method for being able to print Not(IsX) as IS NOT X.
521    pub fn is(&self) -> Option<&'static str> {
522        match self {
523            UnaryFunc::IsNull(_) => Some("NULL"),
524            UnaryFunc::IsTrue(_) => Some("TRUE"),
525            UnaryFunc::IsFalse(_) => Some("FALSE"),
526            _ => None,
527        }
528    }
529}
530
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use mz_repr::{PropDatum, SqlScalarType};

    use crate::like_pattern;

    use super::*;

    /// Each of these functions must report that it cannot error.
    #[mz_ore::test]
    fn test_could_error() {
        let infallible = [
            UnaryFunc::IsNull(IsNull),
            UnaryFunc::CastVarCharToString(CastVarCharToString),
            UnaryFunc::Not(Not),
            UnaryFunc::IsLikeMatch(IsLikeMatch(like_pattern::compile("%hi%", false).unwrap())),
        ];
        for func in &infallible {
            assert!(!func.could_error())
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Evaluates `expr` on each argument list in `datums` and asserts that the results are
        /// either all non-decreasing or all non-increasing. If any evaluation fails, nothing
        /// is checked.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let results = match datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            {
                Ok(results) => results,
                Err(_) => return,
            };

            let non_decreasing = results.iter().tuple_windows().all(|(lo, hi)| lo <= hi);
            let non_increasing = results.iter().tuple_windows().all(|(hi, lo)| hi >= lo);
            assert!(
                non_decreasing || non_increasing,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// If `func` claims to be monotone, property-tests that claim on sorted triples of
        /// arguments drawn from `arg`; otherwise does nothing.
        fn proptest_unary<'a>(
            func: UnaryFunc,
            arena: &'a RowArena,
            arg: impl Strategy<Value = PropDatum>,
        ) {
            if !func.is_monotone() {
                return;
            }
            let expr = MirScalarExpr::CallUnary {
                func,
                expr: Box::new(MirScalarExpr::column(0)),
            };
            proptest!(|(
                mut arg in proptest::array::uniform3(arg),
            )| {
                arg.sort();
                let args: Vec<_> = arg.iter().map(|a| [Datum::from(a)]).collect();
                assert_monotone(&expr, arena, &args);
            });
        }

        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| match interesting_i32s[i] {
                    Datum::Int32(val) => PropDatum::Int32(val),
                    _ => unreachable!("interesting int32 has non-i32s"),
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        let int32_funcs = [
            UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None)),
            UnaryFunc::CastInt32ToUint16(CastInt32ToUint16),
            UnaryFunc::CastInt32ToString(CastInt32ToString),
        ];
        for func in int32_funcs {
            proptest_unary(func, &arena, &i32_datums);
        }
    }
}