Skip to main content

mz_expr/scalar/func/
unary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Definition of the [`UnaryFunc`] enum and related machinery.
15
16use std::{fmt, str};
17
18use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
19
20use crate::Eval;
21use crate::EvalError;
22use crate::scalar::func::RedactSql;
23use crate::scalar::func::impls::*;
24
25/// A description of an SQL unary function that has the ability to lazy evaluate its arguments
26// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
27pub trait LazyUnaryFunc {
28    fn eval<'a>(
29        &'a self,
30        datums: &[Datum<'a>],
31        temp_storage: &'a RowArena,
32        a: &'a impl Eval,
33    ) -> Result<Datum<'a>, EvalError>;
34
35    /// The output SqlColumnType of this function.
36    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
37
38    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
39        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
40    }
41
42    /// Whether this function will produce NULL on NULL input.
43    fn propagates_nulls(&self) -> bool;
44
45    /// Whether this function will produce NULL on non-NULL input.
46    fn introduces_nulls(&self) -> bool;
47
48    /// Whether this function might error on non-error input.
49    fn could_error(&self) -> bool {
50        // NB: override this for functions that never error.
51        true
52    }
53
54    /// Whether this function preserves uniqueness.
55    ///
56    /// Uniqueness is preserved when `if f(x) = f(y) then x = y` is true. This
57    /// is used by the optimizer when a guarantee can be made that a collection
58    /// with unique items will stay unique when mapped by this function.
59    ///
60    /// Note that error results are not covered: Even with `preserves_uniqueness = true`, it can
61    /// happen that two different inputs produce the same error result. (e.g., in case of a
62    /// narrowing cast)
63    ///
64    /// Functions should conservatively return `false` unless they are certain
65    /// the above property is true.
66    fn preserves_uniqueness(&self) -> bool;
67
68    /// The [inverse] of this function, if it has one and we have determined it.
69    ///
70    /// The optimizer _can_ use this information when selecting indexes, e.g. an
71    /// indexed column has a cast applied to it, by moving the right inverse of
72    /// the cast to another value, we can select the indexed column.
73    ///
74    /// Note that a value of `None` does not imply that the inverse does not
75    /// exist; it could also mean we have not yet invested the energy in
76    /// representing it. For example, in the case of complex casts, such as
77    /// between two list types, we could determine the right inverse, but doing
78    /// so is not immediately necessary as this information is only used by the
79    /// optimizer.
80    ///
81    /// ## Right vs. left vs. inverses
82    /// - Right inverses are when the inverse function preserves uniqueness.
83    ///   These are the functions that the optimizer uses to move casts between
84    ///   expressions.
85    /// - Left inverses are when the function itself preserves uniqueness.
86    /// - Inverses are when a function is both a right and a left inverse (e.g.,
87    ///   bit_not_int64 is both a right and left inverse of itself).
88    ///
89    /// We call this function `inverse` for simplicity's sake; it doesn't always
90    /// correspond to the mathematical notion of "inverse." However, in
91    /// conjunction with checks to `preserves_uniqueness` you can determine
92    /// which type of inverse we return.
93    ///
94    /// [inverse]: https://en.wikipedia.org/wiki/Inverse_function
95    fn inverse(&self) -> Option<crate::UnaryFunc>;
96
97    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
98    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
99    /// determine the range of possible outputs just by mapping the endpoints.
100    ///
101    /// This property describes the behaviour of the function over ranges where the function is defined:
102    /// ie. the argument and the result are non-error datums.
103    fn is_monotone(&self) -> bool;
104
105    /// Returns true if the function does no actual work, but is merely a type-tracking cast.
106    fn is_eliminable_cast(&self) -> bool;
107}
108
109/// A description of an SQL unary function that operates on eagerly evaluated expressions
110pub trait EagerUnaryFunc {
111    type Input<'a>: InputDatumType<'a, EvalError>;
112    type Output<'a>: OutputDatumType<'a, EvalError>;
113
114    fn call<'a>(&self, input: Self::Input<'a>) -> Self::Output<'a>;
115
116    /// The output SqlColumnType of this function
117    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
118
119    /// The output of this function as a representation type.
120    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
121        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
122    }
123
124    /// Whether this function will produce NULL on NULL input
125    fn propagates_nulls(&self) -> bool {
126        // If the input is not nullable then nulls are propagated
127        !Self::Input::<'_>::nullable()
128    }
129
130    /// Whether this function will produce NULL on non-NULL input
131    fn introduces_nulls(&self) -> bool {
132        // If the output is nullable then nulls can be introduced
133        Self::Output::<'_>::nullable()
134    }
135
136    /// Whether this function could produce an error
137    fn could_error(&self) -> bool {
138        Self::Output::<'_>::fallible()
139    }
140
141    /// Whether this function preserves uniqueness
142    fn preserves_uniqueness(&self) -> bool {
143        false
144    }
145
146    fn inverse(&self) -> Option<crate::UnaryFunc> {
147        None
148    }
149
150    fn is_monotone(&self) -> bool {
151        false
152    }
153
154    fn is_eliminable_cast(&self) -> bool {
155        false
156    }
157}
158
159impl<T: EagerUnaryFunc> LazyUnaryFunc for T {
160    fn eval<'a>(
161        &'a self,
162        datums: &[Datum<'a>],
163        temp_storage: &'a RowArena,
164        a: &'a impl Eval,
165    ) -> Result<Datum<'a>, EvalError> {
166        match T::Input::<'_>::try_from_result(a.eval(datums, temp_storage)) {
167            // If we can convert to the input type then we call the function
168            Ok(input) => self.call(input).into_result(temp_storage),
169            // If we can't and we got a non-null datum something went wrong in the planner
170            Err(Ok(datum)) if !datum.is_null() => {
171                Err(EvalError::Internal("invalid input type".into()))
172            }
173            // Otherwise we just propagate NULLs and errors
174            Err(res) => res,
175        }
176    }
177
178    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
179        self.output_sql_type(input_type)
180    }
181
182    fn propagates_nulls(&self) -> bool {
183        self.propagates_nulls()
184    }
185
186    fn introduces_nulls(&self) -> bool {
187        self.introduces_nulls()
188    }
189
190    fn could_error(&self) -> bool {
191        self.could_error()
192    }
193
194    fn preserves_uniqueness(&self) -> bool {
195        self.preserves_uniqueness()
196    }
197
198    fn inverse(&self) -> Option<crate::UnaryFunc> {
199        self.inverse()
200    }
201
202    fn is_monotone(&self) -> bool {
203        self.is_monotone()
204    }
205
206    fn is_eliminable_cast(&self) -> bool {
207        self.is_eliminable_cast()
208    }
209}
210
// Generates the `UnaryFunc` enum from the list below. Each entry becomes a variant wrapping
// a value of the same-named function type (e.g. `UnaryFunc::IsNull(IsNull)`). Those types
// are brought into scope by the `use` items at the top of this file.
// NOTE(review): the `derive_unary!` definition is not visible here. `UnaryFunc` exposes
// methods such as `could_error` and `is_monotone` (used by the tests below), presumably by
// dispatching to the wrapped value's `LazyUnaryFunc` impl — confirm against the macro.
// NOTE(review): variant order may be significant (e.g. for derived ordering or
// serialization); confirm before reordering or inserting entries mid-list.
derive_unary!(
    Not,
    IsNull,
    IsTrue,
    IsFalse,
    BitNotInt16,
    BitNotInt32,
    BitNotInt64,
    BitNotUint16,
    BitNotUint32,
    BitNotUint64,
    NegInt16,
    NegInt32,
    NegInt64,
    NegFloat32,
    NegFloat64,
    NegNumeric,
    NegInterval,
    SqrtFloat64,
    SqrtNumeric,
    CbrtFloat64,
    AbsInt16,
    AbsInt32,
    AbsInt64,
    AbsFloat32,
    AbsFloat64,
    AbsNumeric,
    CastBoolToString,
    CastBoolToStringNonstandard,
    CastBoolToInt32,
    CastBoolToInt64,
    CastInt16ToFloat32,
    CastInt16ToFloat64,
    CastInt16ToInt32,
    CastInt16ToInt64,
    CastInt16ToUint16,
    CastInt16ToUint32,
    CastInt16ToUint64,
    CastInt16ToString,
    CastInt2VectorToArray,
    CastInt32ToBool,
    CastInt32ToFloat32,
    CastInt32ToFloat64,
    CastInt32ToOid,
    CastInt32ToPgLegacyChar,
    CastInt32ToInt16,
    CastInt32ToInt64,
    CastInt32ToUint16,
    CastInt32ToUint32,
    CastInt32ToUint64,
    CastInt32ToString,
    CastOidToInt32,
    CastOidToInt64,
    CastOidToString,
    CastOidToRegClass,
    CastRegClassToOid,
    CastOidToRegProc,
    CastRegProcToOid,
    CastOidToRegType,
    CastRegTypeToOid,
    CastInt64ToInt16,
    CastInt64ToInt32,
    CastInt64ToUint16,
    CastInt64ToUint32,
    CastInt64ToUint64,
    CastInt16ToNumeric,
    CastInt32ToNumeric,
    CastInt64ToBool,
    CastInt64ToNumeric,
    CastInt64ToFloat32,
    CastInt64ToFloat64,
    CastInt64ToOid,
    CastInt64ToString,
    CastUint16ToUint32,
    CastUint16ToUint64,
    CastUint16ToInt16,
    CastUint16ToInt32,
    CastUint16ToInt64,
    CastUint16ToNumeric,
    CastUint16ToFloat32,
    CastUint16ToFloat64,
    CastUint16ToString,
    CastUint32ToUint16,
    CastUint32ToUint64,
    CastUint32ToInt16,
    CastUint32ToInt32,
    CastUint32ToInt64,
    CastUint32ToNumeric,
    CastUint32ToFloat32,
    CastUint32ToFloat64,
    CastUint32ToString,
    CastUint64ToUint16,
    CastUint64ToUint32,
    CastUint64ToInt16,
    CastUint64ToInt32,
    CastUint64ToInt64,
    CastUint64ToNumeric,
    CastUint64ToFloat32,
    CastUint64ToFloat64,
    CastUint64ToString,
    CastFloat32ToInt16,
    CastFloat32ToInt32,
    CastFloat32ToInt64,
    CastFloat32ToUint16,
    CastFloat32ToUint32,
    CastFloat32ToUint64,
    CastFloat32ToFloat64,
    CastFloat32ToString,
    CastFloat32ToNumeric,
    CastFloat64ToNumeric,
    CastFloat64ToInt16,
    CastFloat64ToInt32,
    CastFloat64ToInt64,
    CastFloat64ToUint16,
    CastFloat64ToUint32,
    CastFloat64ToUint64,
    CastFloat64ToFloat32,
    CastFloat64ToString,
    CastNumericToFloat32,
    CastNumericToFloat64,
    CastNumericToInt16,
    CastNumericToInt32,
    CastNumericToInt64,
    CastNumericToUint16,
    CastNumericToUint32,
    CastNumericToUint64,
    CastNumericToString,
    CastMzTimestampToString,
    CastMzTimestampToTimestamp,
    CastMzTimestampToTimestampTz,
    CastStringToMzTimestamp,
    CastUint64ToMzTimestamp,
    CastUint32ToMzTimestamp,
    CastInt64ToMzTimestamp,
    CastInt32ToMzTimestamp,
    CastNumericToMzTimestamp,
    CastTimestampToMzTimestamp,
    CastTimestampTzToMzTimestamp,
    CastDateToMzTimestamp,
    CastStringToBool,
    CastStringToPgLegacyChar,
    CastStringToPgLegacyName,
    CastStringToBytes,
    CastStringToInt16,
    CastStringToInt32,
    CastStringToInt64,
    CastStringToUint16,
    CastStringToUint32,
    CastStringToUint64,
    CastStringToInt2Vector,
    CastStringToOid,
    CastStringToFloat32,
    CastStringToFloat64,
    CastStringToDate,
    CastStringToArray,
    CastStringToList,
    CastStringToMap,
    CastStringToRange,
    CastStringToTime,
    CastStringToTimestamp,
    CastStringToTimestampTz,
    CastStringToInterval,
    CastStringToNumeric,
    CastStringToUuid,
    CastStringToChar,
    PadChar,
    CastStringToVarChar,
    CastCharToString,
    CastVarCharToString,
    CastDateToTimestamp,
    CastDateToTimestampTz,
    CastDateToString,
    CastTimeToInterval,
    CastTimeToString,
    CastIntervalToString,
    CastIntervalToTime,
    CastTimestampToDate,
    AdjustTimestampPrecision,
    CastTimestampToTimestampTz,
    CastTimestampToString,
    CastTimestampToTime,
    CastTimestampTzToDate,
    CastTimestampTzToTimestamp,
    AdjustTimestampTzPrecision,
    CastTimestampTzToString,
    CastTimestampTzToTime,
    CastPgLegacyCharToString,
    CastPgLegacyCharToChar,
    CastPgLegacyCharToVarChar,
    CastPgLegacyCharToInt32,
    CastBytesToString,
    CastStringToJsonb,
    CastJsonbToString,
    CastJsonbableToJsonb,
    CastJsonbToInt16,
    CastJsonbToInt32,
    CastJsonbToInt64,
    CastJsonbToFloat32,
    CastJsonbToFloat64,
    CastJsonbToNumeric,
    CastJsonbToBool,
    CastUuidToString,
    CastRecordToString,
    CastRecord1ToRecord2,
    CastArrayToArray,
    CastArrayToJsonb,
    CastArrayToString,
    CastListToString,
    CastListToJsonb,
    CastList1ToList2,
    CastArrayToListOneDim,
    CastMapToString,
    CastInt2VectorToString,
    CastRangeToString,
    CeilFloat32,
    CeilFloat64,
    CeilNumeric,
    FloorFloat32,
    FloorFloat64,
    FloorNumeric,
    Ascii,
    BitCountBytes,
    BitLengthBytes,
    BitLengthString,
    ByteLengthBytes,
    ByteLengthString,
    CharLength,
    Chr,
    IsLikeMatch,
    IsRegexpMatch,
    RegexpMatch,
    ExtractInterval,
    ExtractTime,
    ExtractTimestamp,
    ExtractTimestampTz,
    ExtractDate,
    DatePartInterval,
    DatePartTime,
    DatePartTimestamp,
    DatePartTimestampTz,
    DateTruncTimestamp,
    DateTruncTimestampTz,
    TimezoneTimestamp,
    TimezoneTimestampTz,
    TimezoneTime,
    ToTimestamp,
    ToCharTimestamp,
    ToCharTimestampTz,
    JustifyDays,
    JustifyHours,
    JustifyInterval,
    JsonbArrayLength,
    JsonbTypeof,
    JsonbStripNulls,
    JsonbPretty,
    ParseCatalogCreateSql,
    ParseCatalogId,
    ParseCatalogPrivileges,
    RedactSql,
    RoundFloat32,
    RoundFloat64,
    RoundNumeric,
    TruncFloat32,
    TruncFloat64,
    TruncNumeric,
    TrimWhitespace,
    TrimLeadingWhitespace,
    TrimTrailingWhitespace,
    Initcap,
    RecordGet,
    ListLength,
    MapLength,
    MapBuildFromRecordList,
    Upper,
    Lower,
    Cos,
    Acos,
    Cosh,
    Acosh,
    Sin,
    Asin,
    Sinh,
    Asinh,
    Tan,
    Atan,
    Tanh,
    Atanh,
    Cot,
    Degrees,
    Radians,
    Log10,
    Log10Numeric,
    Ln,
    LnNumeric,
    Exp,
    ExpNumeric,
    Sleep,
    Panic,
    AdjustNumericScale,
    PgColumnSize,
    MzRowSize,
    MzTypeName,
    StepMzTimestamp,
    RangeLower,
    RangeUpper,
    RangeEmpty,
    RangeLowerInc,
    RangeUpperInc,
    RangeLowerInf,
    RangeUpperInf,
    MzAclItemGrantor,
    MzAclItemGrantee,
    MzAclItemPrivileges,
    MzFormatPrivileges,
    MzValidatePrivileges,
    MzValidateRolePrivilege,
    AclItemGrantor,
    AclItemGrantee,
    AclItemPrivileges,
    QuoteIdent,
    TryParseMonotonicIso8601Timestamp,
    RegexpSplitToArray,
    PgSizePretty,
    Crc32Bytes,
    Crc32String,
    KafkaMurmur2Bytes,
    KafkaMurmur2String,
    SeahashBytes,
    SeahashString,
    Reverse
);
542
543impl UnaryFunc {
544    /// If the unary_func represents "IS X", return X.
545    ///
546    /// A helper method for being able to print Not(IsX) as IS NOT X.
547    pub fn is(&self) -> Option<&'static str> {
548        match self {
549            UnaryFunc::IsNull(_) => Some("NULL"),
550            UnaryFunc::IsTrue(_) => Some("TRUE"),
551            UnaryFunc::IsFalse(_) => Some("FALSE"),
552            _ => None,
553        }
554    }
555}
556
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use mz_repr::{PropDatum, SqlScalarType};

    use crate::{MirScalarExpr, like_pattern};

    use super::*;

    /// Checks that functions which cannot fail report `could_error() == false`.
    #[mz_ore::test]
    fn test_could_error() {
        let infallible = [
            UnaryFunc::IsNull(IsNull),
            UnaryFunc::CastVarCharToString(CastVarCharToString),
            UnaryFunc::Not(Not),
            UnaryFunc::IsLikeMatch(IsLikeMatch(like_pattern::compile("%hi%", false).unwrap())),
        ];
        for func in infallible {
            assert!(!func.could_error())
        }
    }

    /// Property-tests that functions claiming `is_monotone()` really are monotone on sorted
    /// inputs.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Panics unless evaluating `expr` on each entry of `datums`, in order, yields a
        /// sequence that is entirely non-decreasing or entirely non-increasing.
        ///
        /// The check is skipped if any evaluation returns an error.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let evaluated: Result<Vec<_>, _> = datums
                .iter()
                .map(|row| expr.eval(row.as_slice(), arena))
                .collect();
            let Ok(results) = evaluated else {
                return;
            };

            let non_decreasing = results.iter().tuple_windows().all(|(lo, hi)| lo <= hi);
            let non_increasing = results.iter().tuple_windows().all(|(hi, lo)| hi >= lo);
            assert!(
                non_decreasing || non_increasing,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Runs [`assert_monotone`] on sorted triples drawn from `arg`, but only when `func`
        /// claims to be monotone.
        fn proptest_unary<'a>(
            func: UnaryFunc,
            arena: &'a RowArena,
            arg: impl Strategy<Value = PropDatum>,
        ) {
            if !func.is_monotone() {
                return;
            }
            let expr = MirScalarExpr::CallUnary {
                func,
                expr: Box::new(MirScalarExpr::column(0)),
            };
            proptest!(|(
                mut arg in proptest::array::uniform3(arg),
            )| {
                arg.sort();
                let args: Vec<_> = arg.iter().map(|a| [Datum::from(a)]).collect();
                assert_monotone(&expr, arena, &args);
            });
        }

        let arena = RowArena::new();

        // Mix arbitrary i32s, the type's "interesting" values, and a dense band near zero.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |idx| {
                    let Datum::Int32(val) = interesting_i32s[idx] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        for func in [
            UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None)),
            UnaryFunc::CastInt32ToUint16(CastInt32ToUint16),
            UnaryFunc::CastInt32ToString(CastInt32ToString),
        ] {
            proptest_unary(func, &arena, &i32_datums);
        }
    }
}