Skip to main content

mz_expr/scalar/func/
unary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Definition of the [`UnaryFunc`] enum and related machinery.
15
16use std::{fmt, str};
17
18use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
19
20use crate::scalar::func::impls::*;
21use crate::{EvalError, MirScalarExpr};
22
23/// A description of an SQL unary function that has the ability to lazy evaluate its arguments
24// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
25pub trait LazyUnaryFunc {
26    fn eval<'a>(
27        &'a self,
28        datums: &[Datum<'a>],
29        temp_storage: &'a RowArena,
30        a: &'a MirScalarExpr,
31    ) -> Result<Datum<'a>, EvalError>;
32
33    /// The output SqlColumnType of this function.
34    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
35
36    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
37        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
38    }
39
40    /// Whether this function will produce NULL on NULL input.
41    fn propagates_nulls(&self) -> bool;
42
43    /// Whether this function will produce NULL on non-NULL input.
44    fn introduces_nulls(&self) -> bool;
45
46    /// Whether this function might error on non-error input.
47    fn could_error(&self) -> bool {
48        // NB: override this for functions that never error.
49        true
50    }
51
52    /// Whether this function preserves uniqueness.
53    ///
54    /// Uniqueness is preserved when `if f(x) = f(y) then x = y` is true. This
55    /// is used by the optimizer when a guarantee can be made that a collection
56    /// with unique items will stay unique when mapped by this function.
57    ///
58    /// Note that error results are not covered: Even with `preserves_uniqueness = true`, it can
59    /// happen that two different inputs produce the same error result. (e.g., in case of a
60    /// narrowing cast)
61    ///
62    /// Functions should conservatively return `false` unless they are certain
63    /// the above property is true.
64    fn preserves_uniqueness(&self) -> bool;
65
66    /// The [inverse] of this function, if it has one and we have determined it.
67    ///
68    /// The optimizer _can_ use this information when selecting indexes, e.g. an
69    /// indexed column has a cast applied to it, by moving the right inverse of
70    /// the cast to another value, we can select the indexed column.
71    ///
72    /// Note that a value of `None` does not imply that the inverse does not
73    /// exist; it could also mean we have not yet invested the energy in
74    /// representing it. For example, in the case of complex casts, such as
75    /// between two list types, we could determine the right inverse, but doing
76    /// so is not immediately necessary as this information is only used by the
77    /// optimizer.
78    ///
79    /// ## Right vs. left vs. inverses
80    /// - Right inverses are when the inverse function preserves uniqueness.
81    ///   These are the functions that the optimizer uses to move casts between
82    ///   expressions.
83    /// - Left inverses are when the function itself preserves uniqueness.
84    /// - Inverses are when a function is both a right and a left inverse (e.g.,
85    ///   bit_not_int64 is both a right and left inverse of itself).
86    ///
87    /// We call this function `inverse` for simplicity's sake; it doesn't always
88    /// correspond to the mathematical notion of "inverse." However, in
89    /// conjunction with checks to `preserves_uniqueness` you can determine
90    /// which type of inverse we return.
91    ///
92    /// [inverse]: https://en.wikipedia.org/wiki/Inverse_function
93    fn inverse(&self) -> Option<crate::UnaryFunc>;
94
95    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
96    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
97    /// determine the range of possible outputs just by mapping the endpoints.
98    ///
99    /// This property describes the behaviour of the function over ranges where the function is defined:
100    /// ie. the argument and the result are non-error datums.
101    fn is_monotone(&self) -> bool;
102
103    /// Returns true if the function does no actual work, but is merely a type-tracking cast.
104    fn is_eliminable_cast(&self) -> bool;
105}
106
107/// A description of an SQL unary function that operates on eagerly evaluated expressions
108pub trait EagerUnaryFunc {
109    type Input<'a>: InputDatumType<'a, EvalError>;
110    type Output<'a>: OutputDatumType<'a, EvalError>;
111
112    fn call<'a>(&self, input: Self::Input<'a>) -> Self::Output<'a>;
113
114    /// The output SqlColumnType of this function
115    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
116
117    /// The output of this function as a representation type.
118    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
119        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
120    }
121
122    /// Whether this function will produce NULL on NULL input
123    fn propagates_nulls(&self) -> bool {
124        // If the input is not nullable then nulls are propagated
125        !Self::Input::<'_>::nullable()
126    }
127
128    /// Whether this function will produce NULL on non-NULL input
129    fn introduces_nulls(&self) -> bool {
130        // If the output is nullable then nulls can be introduced
131        Self::Output::<'_>::nullable()
132    }
133
134    /// Whether this function could produce an error
135    fn could_error(&self) -> bool {
136        Self::Output::<'_>::fallible()
137    }
138
139    /// Whether this function preserves uniqueness
140    fn preserves_uniqueness(&self) -> bool {
141        false
142    }
143
144    fn inverse(&self) -> Option<crate::UnaryFunc> {
145        None
146    }
147
148    fn is_monotone(&self) -> bool {
149        false
150    }
151
152    fn is_eliminable_cast(&self) -> bool {
153        false
154    }
155}
156
157impl<T: EagerUnaryFunc> LazyUnaryFunc for T {
158    fn eval<'a>(
159        &'a self,
160        datums: &[Datum<'a>],
161        temp_storage: &'a RowArena,
162        a: &'a MirScalarExpr,
163    ) -> Result<Datum<'a>, EvalError> {
164        match T::Input::<'_>::try_from_result(a.eval(datums, temp_storage)) {
165            // If we can convert to the input type then we call the function
166            Ok(input) => self.call(input).into_result(temp_storage),
167            // If we can't and we got a non-null datum something went wrong in the planner
168            Err(Ok(datum)) if !datum.is_null() => {
169                Err(EvalError::Internal("invalid input type".into()))
170            }
171            // Otherwise we just propagate NULLs and errors
172            Err(res) => res,
173        }
174    }
175
176    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
177        self.output_sql_type(input_type)
178    }
179
180    fn propagates_nulls(&self) -> bool {
181        self.propagates_nulls()
182    }
183
184    fn introduces_nulls(&self) -> bool {
185        self.introduces_nulls()
186    }
187
188    fn could_error(&self) -> bool {
189        self.could_error()
190    }
191
192    fn preserves_uniqueness(&self) -> bool {
193        self.preserves_uniqueness()
194    }
195
196    fn inverse(&self) -> Option<crate::UnaryFunc> {
197        self.inverse()
198    }
199
200    fn is_monotone(&self) -> bool {
201        self.is_monotone()
202    }
203
204    fn is_eliminable_cast(&self) -> bool {
205        self.is_eliminable_cast()
206    }
207}
208
209derive_unary!(
210    Not,
211    IsNull,
212    IsTrue,
213    IsFalse,
214    BitNotInt16,
215    BitNotInt32,
216    BitNotInt64,
217    BitNotUint16,
218    BitNotUint32,
219    BitNotUint64,
220    NegInt16,
221    NegInt32,
222    NegInt64,
223    NegFloat32,
224    NegFloat64,
225    NegNumeric,
226    NegInterval,
227    SqrtFloat64,
228    SqrtNumeric,
229    CbrtFloat64,
230    AbsInt16,
231    AbsInt32,
232    AbsInt64,
233    AbsFloat32,
234    AbsFloat64,
235    AbsNumeric,
236    CastBoolToString,
237    CastBoolToStringNonstandard,
238    CastBoolToInt32,
239    CastBoolToInt64,
240    CastInt16ToFloat32,
241    CastInt16ToFloat64,
242    CastInt16ToInt32,
243    CastInt16ToInt64,
244    CastInt16ToUint16,
245    CastInt16ToUint32,
246    CastInt16ToUint64,
247    CastInt16ToString,
248    CastInt2VectorToArray,
249    CastInt32ToBool,
250    CastInt32ToFloat32,
251    CastInt32ToFloat64,
252    CastInt32ToOid,
253    CastInt32ToPgLegacyChar,
254    CastInt32ToInt16,
255    CastInt32ToInt64,
256    CastInt32ToUint16,
257    CastInt32ToUint32,
258    CastInt32ToUint64,
259    CastInt32ToString,
260    CastOidToInt32,
261    CastOidToInt64,
262    CastOidToString,
263    CastOidToRegClass,
264    CastRegClassToOid,
265    CastOidToRegProc,
266    CastRegProcToOid,
267    CastOidToRegType,
268    CastRegTypeToOid,
269    CastInt64ToInt16,
270    CastInt64ToInt32,
271    CastInt64ToUint16,
272    CastInt64ToUint32,
273    CastInt64ToUint64,
274    CastInt16ToNumeric,
275    CastInt32ToNumeric,
276    CastInt64ToBool,
277    CastInt64ToNumeric,
278    CastInt64ToFloat32,
279    CastInt64ToFloat64,
280    CastInt64ToOid,
281    CastInt64ToString,
282    CastUint16ToUint32,
283    CastUint16ToUint64,
284    CastUint16ToInt16,
285    CastUint16ToInt32,
286    CastUint16ToInt64,
287    CastUint16ToNumeric,
288    CastUint16ToFloat32,
289    CastUint16ToFloat64,
290    CastUint16ToString,
291    CastUint32ToUint16,
292    CastUint32ToUint64,
293    CastUint32ToInt16,
294    CastUint32ToInt32,
295    CastUint32ToInt64,
296    CastUint32ToNumeric,
297    CastUint32ToFloat32,
298    CastUint32ToFloat64,
299    CastUint32ToString,
300    CastUint64ToUint16,
301    CastUint64ToUint32,
302    CastUint64ToInt16,
303    CastUint64ToInt32,
304    CastUint64ToInt64,
305    CastUint64ToNumeric,
306    CastUint64ToFloat32,
307    CastUint64ToFloat64,
308    CastUint64ToString,
309    CastFloat32ToInt16,
310    CastFloat32ToInt32,
311    CastFloat32ToInt64,
312    CastFloat32ToUint16,
313    CastFloat32ToUint32,
314    CastFloat32ToUint64,
315    CastFloat32ToFloat64,
316    CastFloat32ToString,
317    CastFloat32ToNumeric,
318    CastFloat64ToNumeric,
319    CastFloat64ToInt16,
320    CastFloat64ToInt32,
321    CastFloat64ToInt64,
322    CastFloat64ToUint16,
323    CastFloat64ToUint32,
324    CastFloat64ToUint64,
325    CastFloat64ToFloat32,
326    CastFloat64ToString,
327    CastNumericToFloat32,
328    CastNumericToFloat64,
329    CastNumericToInt16,
330    CastNumericToInt32,
331    CastNumericToInt64,
332    CastNumericToUint16,
333    CastNumericToUint32,
334    CastNumericToUint64,
335    CastNumericToString,
336    CastMzTimestampToString,
337    CastMzTimestampToTimestamp,
338    CastMzTimestampToTimestampTz,
339    CastStringToMzTimestamp,
340    CastUint64ToMzTimestamp,
341    CastUint32ToMzTimestamp,
342    CastInt64ToMzTimestamp,
343    CastInt32ToMzTimestamp,
344    CastNumericToMzTimestamp,
345    CastTimestampToMzTimestamp,
346    CastTimestampTzToMzTimestamp,
347    CastDateToMzTimestamp,
348    CastStringToBool,
349    CastStringToPgLegacyChar,
350    CastStringToPgLegacyName,
351    CastStringToBytes,
352    CastStringToInt16,
353    CastStringToInt32,
354    CastStringToInt64,
355    CastStringToUint16,
356    CastStringToUint32,
357    CastStringToUint64,
358    CastStringToInt2Vector,
359    CastStringToOid,
360    CastStringToFloat32,
361    CastStringToFloat64,
362    CastStringToDate,
363    CastStringToArray,
364    CastStringToList,
365    CastStringToMap,
366    CastStringToRange,
367    CastStringToTime,
368    CastStringToTimestamp,
369    CastStringToTimestampTz,
370    CastStringToInterval,
371    CastStringToNumeric,
372    CastStringToUuid,
373    CastStringToChar,
374    PadChar,
375    CastStringToVarChar,
376    CastCharToString,
377    CastVarCharToString,
378    CastDateToTimestamp,
379    CastDateToTimestampTz,
380    CastDateToString,
381    CastTimeToInterval,
382    CastTimeToString,
383    CastIntervalToString,
384    CastIntervalToTime,
385    CastTimestampToDate,
386    AdjustTimestampPrecision,
387    CastTimestampToTimestampTz,
388    CastTimestampToString,
389    CastTimestampToTime,
390    CastTimestampTzToDate,
391    CastTimestampTzToTimestamp,
392    AdjustTimestampTzPrecision,
393    CastTimestampTzToString,
394    CastTimestampTzToTime,
395    CastPgLegacyCharToString,
396    CastPgLegacyCharToChar,
397    CastPgLegacyCharToVarChar,
398    CastPgLegacyCharToInt32,
399    CastBytesToString,
400    CastStringToJsonb,
401    CastJsonbToString,
402    CastJsonbableToJsonb,
403    CastJsonbToInt16,
404    CastJsonbToInt32,
405    CastJsonbToInt64,
406    CastJsonbToFloat32,
407    CastJsonbToFloat64,
408    CastJsonbToNumeric,
409    CastJsonbToBool,
410    CastUuidToString,
411    CastRecordToString,
412    CastRecord1ToRecord2,
413    CastArrayToArray,
414    CastArrayToJsonb,
415    CastArrayToString,
416    CastListToString,
417    CastListToJsonb,
418    CastList1ToList2,
419    CastArrayToListOneDim,
420    CastMapToString,
421    CastInt2VectorToString,
422    CastRangeToString,
423    CeilFloat32,
424    CeilFloat64,
425    CeilNumeric,
426    FloorFloat32,
427    FloorFloat64,
428    FloorNumeric,
429    Ascii,
430    BitCountBytes,
431    BitLengthBytes,
432    BitLengthString,
433    ByteLengthBytes,
434    ByteLengthString,
435    CharLength,
436    Chr,
437    IsLikeMatch,
438    IsRegexpMatch,
439    RegexpMatch,
440    ExtractInterval,
441    ExtractTime,
442    ExtractTimestamp,
443    ExtractTimestampTz,
444    ExtractDate,
445    DatePartInterval,
446    DatePartTime,
447    DatePartTimestamp,
448    DatePartTimestampTz,
449    DateTruncTimestamp,
450    DateTruncTimestampTz,
451    TimezoneTimestamp,
452    TimezoneTimestampTz,
453    TimezoneTime,
454    ToTimestamp,
455    ToCharTimestamp,
456    ToCharTimestampTz,
457    JustifyDays,
458    JustifyHours,
459    JustifyInterval,
460    JsonbArrayLength,
461    JsonbTypeof,
462    JsonbStripNulls,
463    JsonbPretty,
464    ParseCatalogId,
465    ParseCatalogPrivileges,
466    RoundFloat32,
467    RoundFloat64,
468    RoundNumeric,
469    TruncFloat32,
470    TruncFloat64,
471    TruncNumeric,
472    TrimWhitespace,
473    TrimLeadingWhitespace,
474    TrimTrailingWhitespace,
475    Initcap,
476    RecordGet,
477    ListLength,
478    MapLength,
479    MapBuildFromRecordList,
480    Upper,
481    Lower,
482    Cos,
483    Acos,
484    Cosh,
485    Acosh,
486    Sin,
487    Asin,
488    Sinh,
489    Asinh,
490    Tan,
491    Atan,
492    Tanh,
493    Atanh,
494    Cot,
495    Degrees,
496    Radians,
497    Log10,
498    Log10Numeric,
499    Ln,
500    LnNumeric,
501    Exp,
502    ExpNumeric,
503    Sleep,
504    Panic,
505    AdjustNumericScale,
506    PgColumnSize,
507    MzRowSize,
508    MzTypeName,
509    StepMzTimestamp,
510    RangeLower,
511    RangeUpper,
512    RangeEmpty,
513    RangeLowerInc,
514    RangeUpperInc,
515    RangeLowerInf,
516    RangeUpperInf,
517    MzAclItemGrantor,
518    MzAclItemGrantee,
519    MzAclItemPrivileges,
520    MzFormatPrivileges,
521    MzValidatePrivileges,
522    MzValidateRolePrivilege,
523    AclItemGrantor,
524    AclItemGrantee,
525    AclItemPrivileges,
526    QuoteIdent,
527    TryParseMonotonicIso8601Timestamp,
528    RegexpSplitToArray,
529    PgSizePretty,
530    Crc32Bytes,
531    Crc32String,
532    KafkaMurmur2Bytes,
533    KafkaMurmur2String,
534    SeahashBytes,
535    SeahashString,
536    Reverse
537);
538
539impl UnaryFunc {
540    /// If the unary_func represents "IS X", return X.
541    ///
542    /// A helper method for being able to print Not(IsX) as IS NOT X.
543    pub fn is(&self) -> Option<&'static str> {
544        match self {
545            UnaryFunc::IsNull(_) => Some("NULL"),
546            UnaryFunc::IsTrue(_) => Some("TRUE"),
547            UnaryFunc::IsFalse(_) => Some("FALSE"),
548            _ => None,
549        }
550    }
551}
552
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use mz_repr::{PropDatum, SqlScalarType};

    use crate::like_pattern;

    use super::*;

    /// Spot-checks a few functions that must report that they cannot error.
    #[mz_ore::test]
    fn test_could_error() {
        for func in [
            UnaryFunc::IsNull(IsNull),
            UnaryFunc::CastVarCharToString(CastVarCharToString),
            UnaryFunc::Not(Not),
            UnaryFunc::IsLikeMatch(IsLikeMatch(like_pattern::compile("%hi%", false).unwrap())),
        ] {
            assert!(!func.could_error())
        }
    }

    /// Property test: any function that claims `is_monotone()` must map sorted
    /// inputs to outputs that are sorted in one direction or the other.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that the function is either monotonically increasing or decreasing over
        /// the given sets of arguments.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            // If any evaluation errors, the whole sample is skipped rather than failed.
            let Ok(results) = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            else {
                return;
            };

            let forward = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let reverse = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                forward || reverse,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Runs the monotonicity property for `func` over triples drawn from `arg`.
        /// Functions that do not claim to be monotone are not exercised at all.
        fn proptest_unary<'a>(
            func: UnaryFunc,
            arena: &'a RowArena,
            arg: impl Strategy<Value = PropDatum>,
        ) {
            let is_monotone = func.is_monotone();
            let expr = MirScalarExpr::CallUnary {
                func,
                expr: Box::new(MirScalarExpr::column(0)),
            };
            if is_monotone {
                proptest!(|(
                    mut arg in proptest::array::uniform3(arg),
                )| {
                    // Sort the inputs so that the outputs can be checked for ordering.
                    arg.sort();
                    let args: Vec<_> = arg.iter().map(|a| [Datum::from(a)]).collect();
                    assert_monotone(&expr, arena, &args);
                });
            }
        }

        // Mix arbitrary i32s with the type's "interesting" values and a dense
        // range of small integers.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| {
                    let Datum::Int32(val) = interesting_i32s[i] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_unary(
            UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None)),
            &arena,
            &i32_datums,
        );
        proptest_unary(
            UnaryFunc::CastInt32ToUint16(CastInt32ToUint16),
            &arena,
            &i32_datums,
        );
        proptest_unary(
            UnaryFunc::CastInt32ToString(CastInt32ToString),
            &arena,
            &i32_datums,
        );
    }
}