Skip to main content

mz_expr/scalar/func/
unary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Definition of the [`UnaryFunc`] enum and related machinery.
15
16use std::{fmt, str};
17
18use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
19
20use crate::scalar::func::RedactSql;
21use crate::scalar::func::impls::*;
22use crate::{EvalError, MirScalarExpr};
23
24/// A description of an SQL unary function that has the ability to lazy evaluate its arguments
25// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
26pub trait LazyUnaryFunc {
27    fn eval<'a>(
28        &'a self,
29        datums: &[Datum<'a>],
30        temp_storage: &'a RowArena,
31        a: &'a MirScalarExpr,
32    ) -> Result<Datum<'a>, EvalError>;
33
34    /// The output SqlColumnType of this function.
35    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
36
37    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
38        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
39    }
40
41    /// Whether this function will produce NULL on NULL input.
42    fn propagates_nulls(&self) -> bool;
43
44    /// Whether this function will produce NULL on non-NULL input.
45    fn introduces_nulls(&self) -> bool;
46
47    /// Whether this function might error on non-error input.
48    fn could_error(&self) -> bool {
49        // NB: override this for functions that never error.
50        true
51    }
52
53    /// Whether this function preserves uniqueness.
54    ///
55    /// Uniqueness is preserved when `if f(x) = f(y) then x = y` is true. This
56    /// is used by the optimizer when a guarantee can be made that a collection
57    /// with unique items will stay unique when mapped by this function.
58    ///
59    /// Note that error results are not covered: Even with `preserves_uniqueness = true`, it can
60    /// happen that two different inputs produce the same error result. (e.g., in case of a
61    /// narrowing cast)
62    ///
63    /// Functions should conservatively return `false` unless they are certain
64    /// the above property is true.
65    fn preserves_uniqueness(&self) -> bool;
66
67    /// The [inverse] of this function, if it has one and we have determined it.
68    ///
69    /// The optimizer _can_ use this information when selecting indexes, e.g. an
70    /// indexed column has a cast applied to it, by moving the right inverse of
71    /// the cast to another value, we can select the indexed column.
72    ///
73    /// Note that a value of `None` does not imply that the inverse does not
74    /// exist; it could also mean we have not yet invested the energy in
75    /// representing it. For example, in the case of complex casts, such as
76    /// between two list types, we could determine the right inverse, but doing
77    /// so is not immediately necessary as this information is only used by the
78    /// optimizer.
79    ///
80    /// ## Right vs. left vs. inverses
81    /// - Right inverses are when the inverse function preserves uniqueness.
82    ///   These are the functions that the optimizer uses to move casts between
83    ///   expressions.
84    /// - Left inverses are when the function itself preserves uniqueness.
85    /// - Inverses are when a function is both a right and a left inverse (e.g.,
86    ///   bit_not_int64 is both a right and left inverse of itself).
87    ///
88    /// We call this function `inverse` for simplicity's sake; it doesn't always
89    /// correspond to the mathematical notion of "inverse." However, in
90    /// conjunction with checks to `preserves_uniqueness` you can determine
91    /// which type of inverse we return.
92    ///
93    /// [inverse]: https://en.wikipedia.org/wiki/Inverse_function
94    fn inverse(&self) -> Option<crate::UnaryFunc>;
95
96    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
97    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
98    /// determine the range of possible outputs just by mapping the endpoints.
99    ///
100    /// This property describes the behaviour of the function over ranges where the function is defined:
101    /// ie. the argument and the result are non-error datums.
102    fn is_monotone(&self) -> bool;
103
104    /// Returns true if the function does no actual work, but is merely a type-tracking cast.
105    fn is_eliminable_cast(&self) -> bool;
106}
107
108/// A description of an SQL unary function that operates on eagerly evaluated expressions
109pub trait EagerUnaryFunc {
110    type Input<'a>: InputDatumType<'a, EvalError>;
111    type Output<'a>: OutputDatumType<'a, EvalError>;
112
113    fn call<'a>(&self, input: Self::Input<'a>) -> Self::Output<'a>;
114
115    /// The output SqlColumnType of this function
116    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
117
118    /// The output of this function as a representation type.
119    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
120        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
121    }
122
123    /// Whether this function will produce NULL on NULL input
124    fn propagates_nulls(&self) -> bool {
125        // If the input is not nullable then nulls are propagated
126        !Self::Input::<'_>::nullable()
127    }
128
129    /// Whether this function will produce NULL on non-NULL input
130    fn introduces_nulls(&self) -> bool {
131        // If the output is nullable then nulls can be introduced
132        Self::Output::<'_>::nullable()
133    }
134
135    /// Whether this function could produce an error
136    fn could_error(&self) -> bool {
137        Self::Output::<'_>::fallible()
138    }
139
140    /// Whether this function preserves uniqueness
141    fn preserves_uniqueness(&self) -> bool {
142        false
143    }
144
145    fn inverse(&self) -> Option<crate::UnaryFunc> {
146        None
147    }
148
149    fn is_monotone(&self) -> bool {
150        false
151    }
152
153    fn is_eliminable_cast(&self) -> bool {
154        false
155    }
156}
157
158impl<T: EagerUnaryFunc> LazyUnaryFunc for T {
159    fn eval<'a>(
160        &'a self,
161        datums: &[Datum<'a>],
162        temp_storage: &'a RowArena,
163        a: &'a MirScalarExpr,
164    ) -> Result<Datum<'a>, EvalError> {
165        match T::Input::<'_>::try_from_result(a.eval(datums, temp_storage)) {
166            // If we can convert to the input type then we call the function
167            Ok(input) => self.call(input).into_result(temp_storage),
168            // If we can't and we got a non-null datum something went wrong in the planner
169            Err(Ok(datum)) if !datum.is_null() => {
170                Err(EvalError::Internal("invalid input type".into()))
171            }
172            // Otherwise we just propagate NULLs and errors
173            Err(res) => res,
174        }
175    }
176
177    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
178        self.output_sql_type(input_type)
179    }
180
181    fn propagates_nulls(&self) -> bool {
182        self.propagates_nulls()
183    }
184
185    fn introduces_nulls(&self) -> bool {
186        self.introduces_nulls()
187    }
188
189    fn could_error(&self) -> bool {
190        self.could_error()
191    }
192
193    fn preserves_uniqueness(&self) -> bool {
194        self.preserves_uniqueness()
195    }
196
197    fn inverse(&self) -> Option<crate::UnaryFunc> {
198        self.inverse()
199    }
200
201    fn is_monotone(&self) -> bool {
202        self.is_monotone()
203    }
204
205    fn is_eliminable_cast(&self) -> bool {
206        self.is_eliminable_cast()
207    }
208}
209
// Declares the `UnaryFunc` enum from the list of function types below. Each listed name is a
// function type imported into this module (presumably mostly via `impls::*`; `RedactSql` is
// imported explicitly). Each name is also used as the name of the variant that wraps it,
// e.g. `UnaryFunc::IsNull(IsNull)` in the tests at the bottom of this file.
//
// NOTE(review): what `derive_unary!` generates beyond the enum (trait dispatch, Display,
// etc.) lives with the macro definition and is not visible here. Variant order may matter to
// generated or derived impls (e.g. orderings or serialization) — confirm before reordering
// existing entries or inserting into the middle of the list.
derive_unary!(
    Not,
    IsNull,
    IsTrue,
    IsFalse,
    BitNotInt16,
    BitNotInt32,
    BitNotInt64,
    BitNotUint16,
    BitNotUint32,
    BitNotUint64,
    NegInt16,
    NegInt32,
    NegInt64,
    NegFloat32,
    NegFloat64,
    NegNumeric,
    NegInterval,
    SqrtFloat64,
    SqrtNumeric,
    CbrtFloat64,
    AbsInt16,
    AbsInt32,
    AbsInt64,
    AbsFloat32,
    AbsFloat64,
    AbsNumeric,
    CastBoolToString,
    CastBoolToStringNonstandard,
    CastBoolToInt32,
    CastBoolToInt64,
    CastInt16ToFloat32,
    CastInt16ToFloat64,
    CastInt16ToInt32,
    CastInt16ToInt64,
    CastInt16ToUint16,
    CastInt16ToUint32,
    CastInt16ToUint64,
    CastInt16ToString,
    CastInt2VectorToArray,
    CastInt32ToBool,
    CastInt32ToFloat32,
    CastInt32ToFloat64,
    CastInt32ToOid,
    CastInt32ToPgLegacyChar,
    CastInt32ToInt16,
    CastInt32ToInt64,
    CastInt32ToUint16,
    CastInt32ToUint32,
    CastInt32ToUint64,
    CastInt32ToString,
    CastOidToInt32,
    CastOidToInt64,
    CastOidToString,
    CastOidToRegClass,
    CastRegClassToOid,
    CastOidToRegProc,
    CastRegProcToOid,
    CastOidToRegType,
    CastRegTypeToOid,
    CastInt64ToInt16,
    CastInt64ToInt32,
    CastInt64ToUint16,
    CastInt64ToUint32,
    CastInt64ToUint64,
    CastInt16ToNumeric,
    CastInt32ToNumeric,
    CastInt64ToBool,
    CastInt64ToNumeric,
    CastInt64ToFloat32,
    CastInt64ToFloat64,
    CastInt64ToOid,
    CastInt64ToString,
    CastUint16ToUint32,
    CastUint16ToUint64,
    CastUint16ToInt16,
    CastUint16ToInt32,
    CastUint16ToInt64,
    CastUint16ToNumeric,
    CastUint16ToFloat32,
    CastUint16ToFloat64,
    CastUint16ToString,
    CastUint32ToUint16,
    CastUint32ToUint64,
    CastUint32ToInt16,
    CastUint32ToInt32,
    CastUint32ToInt64,
    CastUint32ToNumeric,
    CastUint32ToFloat32,
    CastUint32ToFloat64,
    CastUint32ToString,
    CastUint64ToUint16,
    CastUint64ToUint32,
    CastUint64ToInt16,
    CastUint64ToInt32,
    CastUint64ToInt64,
    CastUint64ToNumeric,
    CastUint64ToFloat32,
    CastUint64ToFloat64,
    CastUint64ToString,
    CastFloat32ToInt16,
    CastFloat32ToInt32,
    CastFloat32ToInt64,
    CastFloat32ToUint16,
    CastFloat32ToUint32,
    CastFloat32ToUint64,
    CastFloat32ToFloat64,
    CastFloat32ToString,
    CastFloat32ToNumeric,
    CastFloat64ToNumeric,
    CastFloat64ToInt16,
    CastFloat64ToInt32,
    CastFloat64ToInt64,
    CastFloat64ToUint16,
    CastFloat64ToUint32,
    CastFloat64ToUint64,
    CastFloat64ToFloat32,
    CastFloat64ToString,
    CastNumericToFloat32,
    CastNumericToFloat64,
    CastNumericToInt16,
    CastNumericToInt32,
    CastNumericToInt64,
    CastNumericToUint16,
    CastNumericToUint32,
    CastNumericToUint64,
    CastNumericToString,
    CastMzTimestampToString,
    CastMzTimestampToTimestamp,
    CastMzTimestampToTimestampTz,
    CastStringToMzTimestamp,
    CastUint64ToMzTimestamp,
    CastUint32ToMzTimestamp,
    CastInt64ToMzTimestamp,
    CastInt32ToMzTimestamp,
    CastNumericToMzTimestamp,
    CastTimestampToMzTimestamp,
    CastTimestampTzToMzTimestamp,
    CastDateToMzTimestamp,
    CastStringToBool,
    CastStringToPgLegacyChar,
    CastStringToPgLegacyName,
    CastStringToBytes,
    CastStringToInt16,
    CastStringToInt32,
    CastStringToInt64,
    CastStringToUint16,
    CastStringToUint32,
    CastStringToUint64,
    CastStringToInt2Vector,
    CastStringToOid,
    CastStringToFloat32,
    CastStringToFloat64,
    CastStringToDate,
    CastStringToArray,
    CastStringToList,
    CastStringToMap,
    CastStringToRange,
    CastStringToTime,
    CastStringToTimestamp,
    CastStringToTimestampTz,
    CastStringToInterval,
    CastStringToNumeric,
    CastStringToUuid,
    CastStringToChar,
    PadChar,
    CastStringToVarChar,
    CastCharToString,
    CastVarCharToString,
    CastDateToTimestamp,
    CastDateToTimestampTz,
    CastDateToString,
    CastTimeToInterval,
    CastTimeToString,
    CastIntervalToString,
    CastIntervalToTime,
    CastTimestampToDate,
    AdjustTimestampPrecision,
    CastTimestampToTimestampTz,
    CastTimestampToString,
    CastTimestampToTime,
    CastTimestampTzToDate,
    CastTimestampTzToTimestamp,
    AdjustTimestampTzPrecision,
    CastTimestampTzToString,
    CastTimestampTzToTime,
    CastPgLegacyCharToString,
    CastPgLegacyCharToChar,
    CastPgLegacyCharToVarChar,
    CastPgLegacyCharToInt32,
    CastBytesToString,
    CastStringToJsonb,
    CastJsonbToString,
    CastJsonbableToJsonb,
    CastJsonbToInt16,
    CastJsonbToInt32,
    CastJsonbToInt64,
    CastJsonbToFloat32,
    CastJsonbToFloat64,
    CastJsonbToNumeric,
    CastJsonbToBool,
    CastUuidToString,
    CastRecordToString,
    CastRecord1ToRecord2,
    CastArrayToArray,
    CastArrayToJsonb,
    CastArrayToString,
    CastListToString,
    CastListToJsonb,
    CastList1ToList2,
    CastArrayToListOneDim,
    CastMapToString,
    CastInt2VectorToString,
    CastRangeToString,
    CeilFloat32,
    CeilFloat64,
    CeilNumeric,
    FloorFloat32,
    FloorFloat64,
    FloorNumeric,
    Ascii,
    BitCountBytes,
    BitLengthBytes,
    BitLengthString,
    ByteLengthBytes,
    ByteLengthString,
    CharLength,
    Chr,
    IsLikeMatch,
    IsRegexpMatch,
    RegexpMatch,
    ExtractInterval,
    ExtractTime,
    ExtractTimestamp,
    ExtractTimestampTz,
    ExtractDate,
    DatePartInterval,
    DatePartTime,
    DatePartTimestamp,
    DatePartTimestampTz,
    DateTruncTimestamp,
    DateTruncTimestampTz,
    TimezoneTimestamp,
    TimezoneTimestampTz,
    TimezoneTime,
    ToTimestamp,
    ToCharTimestamp,
    ToCharTimestampTz,
    JustifyDays,
    JustifyHours,
    JustifyInterval,
    JsonbArrayLength,
    JsonbTypeof,
    JsonbStripNulls,
    JsonbPretty,
    ParseCatalogCreateSql,
    ParseCatalogId,
    ParseCatalogPrivileges,
    RedactSql,
    RoundFloat32,
    RoundFloat64,
    RoundNumeric,
    TruncFloat32,
    TruncFloat64,
    TruncNumeric,
    TrimWhitespace,
    TrimLeadingWhitespace,
    TrimTrailingWhitespace,
    Initcap,
    RecordGet,
    ListLength,
    MapLength,
    MapBuildFromRecordList,
    Upper,
    Lower,
    Cos,
    Acos,
    Cosh,
    Acosh,
    Sin,
    Asin,
    Sinh,
    Asinh,
    Tan,
    Atan,
    Tanh,
    Atanh,
    Cot,
    Degrees,
    Radians,
    Log10,
    Log10Numeric,
    Ln,
    LnNumeric,
    Exp,
    ExpNumeric,
    Sleep,
    Panic,
    AdjustNumericScale,
    PgColumnSize,
    MzRowSize,
    MzTypeName,
    StepMzTimestamp,
    RangeLower,
    RangeUpper,
    RangeEmpty,
    RangeLowerInc,
    RangeUpperInc,
    RangeLowerInf,
    RangeUpperInf,
    MzAclItemGrantor,
    MzAclItemGrantee,
    MzAclItemPrivileges,
    MzFormatPrivileges,
    MzValidatePrivileges,
    MzValidateRolePrivilege,
    AclItemGrantor,
    AclItemGrantee,
    AclItemPrivileges,
    QuoteIdent,
    TryParseMonotonicIso8601Timestamp,
    RegexpSplitToArray,
    PgSizePretty,
    Crc32Bytes,
    Crc32String,
    KafkaMurmur2Bytes,
    KafkaMurmur2String,
    SeahashBytes,
    SeahashString,
    Reverse
);
541
542impl UnaryFunc {
543    /// If the unary_func represents "IS X", return X.
544    ///
545    /// A helper method for being able to print Not(IsX) as IS NOT X.
546    pub fn is(&self) -> Option<&'static str> {
547        match self {
548            UnaryFunc::IsNull(_) => Some("NULL"),
549            UnaryFunc::IsTrue(_) => Some("TRUE"),
550            UnaryFunc::IsFalse(_) => Some("FALSE"),
551            _ => None,
552        }
553    }
554}
555
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use mz_repr::{PropDatum, SqlScalarType};

    use crate::like_pattern;

    use super::*;

    /// Functions that can never fail must report `could_error() == false`.
    #[mz_ore::test]
    fn test_could_error() {
        for func in [
            UnaryFunc::IsNull(IsNull),
            UnaryFunc::CastVarCharToString(CastVarCharToString),
            UnaryFunc::Not(Not),
            UnaryFunc::IsLikeMatch(IsLikeMatch(like_pattern::compile("%hi%", false).unwrap())),
        ] {
            // Name the offending function; otherwise a failure inside this loop is anonymous.
            assert!(!func.could_error(), "expected {func:?} to report that it cannot error");
        }
    }

    /// Property-tests a hand-picked subset of functions against their `is_monotone` claim.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that the function is either monotonically increasing or decreasing over
        /// the given sets of arguments.
        ///
        /// If evaluating any of the argument sets returns an error, the check is skipped.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let Ok(results) = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            else {
                return;
            };

            let forward = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let reverse = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                forward || reverse,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// If `func` claims to be monotone, applies it to column 0 and checks
        /// `assert_monotone` on sorted triples drawn from `arg`. Functions that do not claim
        /// monotonicity are not exercised.
        fn proptest_unary<'a>(
            func: UnaryFunc,
            arena: &'a RowArena,
            arg: impl Strategy<Value = PropDatum>,
        ) {
            let is_monotone = func.is_monotone();
            let expr = MirScalarExpr::CallUnary {
                func,
                expr: Box::new(MirScalarExpr::column(0)),
            };
            if is_monotone {
                proptest!(|(
                    mut arg in proptest::array::uniform3(arg),
                )| {
                    arg.sort();
                    let args: Vec<_> = arg.iter().map(|a| [Datum::from(a)]).collect();
                    assert_monotone(&expr, arena, &args);
                });
            }
        }

        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| {
                    let Datum::Int32(val) = interesting_i32s[i] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_unary(
            UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None)),
            &arena,
            &i32_datums,
        );
        proptest_unary(
            UnaryFunc::CastInt32ToUint16(CastInt32ToUint16),
            &arena,
            &i32_datums,
        );
        proptest_unary(
            UnaryFunc::CastInt32ToString(CastInt32ToString),
            &arena,
            &i32_datums,
        );
    }
}