Skip to main content

mz_expr/scalar/func/
unary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9//
10// Portions of this file are derived from the PostgreSQL project. The original
11// source code is subject to the terms of the PostgreSQL license, a copy of
12// which can be found in the LICENSE file at the root of this repository.
13
14//! Definition of the [`UnaryFunc`] enum and related machinery.
15
16use std::{fmt, str};
17
18use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
19
20use crate::scalar::func::impls::*;
21use crate::{EvalError, MirScalarExpr};
22
23/// A description of an SQL unary function that has the ability to lazy evaluate its arguments
24// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
25pub trait LazyUnaryFunc {
26    fn eval<'a>(
27        &'a self,
28        datums: &[Datum<'a>],
29        temp_storage: &'a RowArena,
30        a: &'a MirScalarExpr,
31    ) -> Result<Datum<'a>, EvalError>;
32
33    /// The output SqlColumnType of this function.
34    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
35
36    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
37        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
38    }
39
40    /// Whether this function will produce NULL on NULL input.
41    fn propagates_nulls(&self) -> bool;
42
43    /// Whether this function will produce NULL on non-NULL input.
44    fn introduces_nulls(&self) -> bool;
45
46    /// Whether this function might error on non-error input.
47    fn could_error(&self) -> bool {
48        // NB: override this for functions that never error.
49        true
50    }
51
52    /// Whether this function preserves uniqueness.
53    ///
54    /// Uniqueness is preserved when `if f(x) = f(y) then x = y` is true. This
55    /// is used by the optimizer when a guarantee can be made that a collection
56    /// with unique items will stay unique when mapped by this function.
57    ///
58    /// Note that error results are not covered: Even with `preserves_uniqueness = true`, it can
59    /// happen that two different inputs produce the same error result. (e.g., in case of a
60    /// narrowing cast)
61    ///
62    /// Functions should conservatively return `false` unless they are certain
63    /// the above property is true.
64    fn preserves_uniqueness(&self) -> bool;
65
66    /// The [inverse] of this function, if it has one and we have determined it.
67    ///
68    /// The optimizer _can_ use this information when selecting indexes, e.g. an
69    /// indexed column has a cast applied to it, by moving the right inverse of
70    /// the cast to another value, we can select the indexed column.
71    ///
72    /// Note that a value of `None` does not imply that the inverse does not
73    /// exist; it could also mean we have not yet invested the energy in
74    /// representing it. For example, in the case of complex casts, such as
75    /// between two list types, we could determine the right inverse, but doing
76    /// so is not immediately necessary as this information is only used by the
77    /// optimizer.
78    ///
79    /// ## Right vs. left vs. inverses
80    /// - Right inverses are when the inverse function preserves uniqueness.
81    ///   These are the functions that the optimizer uses to move casts between
82    ///   expressions.
83    /// - Left inverses are when the function itself preserves uniqueness.
84    /// - Inverses are when a function is both a right and a left inverse (e.g.,
85    ///   bit_not_int64 is both a right and left inverse of itself).
86    ///
87    /// We call this function `inverse` for simplicity's sake; it doesn't always
88    /// correspond to the mathematical notion of "inverse." However, in
89    /// conjunction with checks to `preserves_uniqueness` you can determine
90    /// which type of inverse we return.
91    ///
92    /// [inverse]: https://en.wikipedia.org/wiki/Inverse_function
93    fn inverse(&self) -> Option<crate::UnaryFunc>;
94
95    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
96    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
97    /// determine the range of possible outputs just by mapping the endpoints.
98    ///
99    /// This property describes the behaviour of the function over ranges where the function is defined:
100    /// ie. the argument and the result are non-error datums.
101    fn is_monotone(&self) -> bool;
102}
103
104/// A description of an SQL unary function that operates on eagerly evaluated expressions
105pub trait EagerUnaryFunc {
106    type Input<'a>: InputDatumType<'a, EvalError>;
107    type Output<'a>: OutputDatumType<'a, EvalError>;
108
109    fn call<'a>(&self, input: Self::Input<'a>) -> Self::Output<'a>;
110
111    /// The output SqlColumnType of this function
112    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType;
113
114    /// The output of this function as a representation type.
115    fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
116        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
117    }
118
119    /// Whether this function will produce NULL on NULL input
120    fn propagates_nulls(&self) -> bool {
121        // If the input is not nullable then nulls are propagated
122        !Self::Input::<'_>::nullable()
123    }
124
125    /// Whether this function will produce NULL on non-NULL input
126    fn introduces_nulls(&self) -> bool {
127        // If the output is nullable then nulls can be introduced
128        Self::Output::<'_>::nullable()
129    }
130
131    /// Whether this function could produce an error
132    fn could_error(&self) -> bool {
133        Self::Output::<'_>::fallible()
134    }
135
136    /// Whether this function preserves uniqueness
137    fn preserves_uniqueness(&self) -> bool {
138        false
139    }
140
141    fn inverse(&self) -> Option<crate::UnaryFunc> {
142        None
143    }
144
145    fn is_monotone(&self) -> bool {
146        false
147    }
148}
149
150impl<T: EagerUnaryFunc> LazyUnaryFunc for T {
151    fn eval<'a>(
152        &'a self,
153        datums: &[Datum<'a>],
154        temp_storage: &'a RowArena,
155        a: &'a MirScalarExpr,
156    ) -> Result<Datum<'a>, EvalError> {
157        match T::Input::<'_>::try_from_result(a.eval(datums, temp_storage)) {
158            // If we can convert to the input type then we call the function
159            Ok(input) => self.call(input).into_result(temp_storage),
160            // If we can't and we got a non-null datum something went wrong in the planner
161            Err(Ok(datum)) if !datum.is_null() => {
162                Err(EvalError::Internal("invalid input type".into()))
163            }
164            // Otherwise we just propagate NULLs and errors
165            Err(res) => res,
166        }
167    }
168
169    fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
170        self.output_sql_type(input_type)
171    }
172
173    fn propagates_nulls(&self) -> bool {
174        self.propagates_nulls()
175    }
176
177    fn introduces_nulls(&self) -> bool {
178        self.introduces_nulls()
179    }
180
181    fn could_error(&self) -> bool {
182        self.could_error()
183    }
184
185    fn preserves_uniqueness(&self) -> bool {
186        self.preserves_uniqueness()
187    }
188
189    fn inverse(&self) -> Option<crate::UnaryFunc> {
190        self.inverse()
191    }
192
193    fn is_monotone(&self) -> bool {
194        self.is_monotone()
195    }
196}
197
// Declares the `UnaryFunc` enum. Each name listed below becomes a `UnaryFunc`
// variant wrapping the function type of the same name, e.g.
// `UnaryFunc::IsNull(IsNull)` or
// `UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None))`.
//
// NOTE(review): the `derive_unary!` definition is not visible here. If the
// generated enum derives ordering, hashing or serialization from declaration
// order, append new functions at the end rather than inserting them, and
// confirm before reordering this list. The last entry has no trailing comma;
// check whether the macro pattern accepts one before adding it.
derive_unary!(
    Not,
    IsNull,
    IsTrue,
    IsFalse,
    BitNotInt16,
    BitNotInt32,
    BitNotInt64,
    BitNotUint16,
    BitNotUint32,
    BitNotUint64,
    NegInt16,
    NegInt32,
    NegInt64,
    NegFloat32,
    NegFloat64,
    NegNumeric,
    NegInterval,
    SqrtFloat64,
    SqrtNumeric,
    CbrtFloat64,
    AbsInt16,
    AbsInt32,
    AbsInt64,
    AbsFloat32,
    AbsFloat64,
    AbsNumeric,
    CastBoolToString,
    CastBoolToStringNonstandard,
    CastBoolToInt32,
    CastBoolToInt64,
    CastInt16ToFloat32,
    CastInt16ToFloat64,
    CastInt16ToInt32,
    CastInt16ToInt64,
    CastInt16ToUint16,
    CastInt16ToUint32,
    CastInt16ToUint64,
    CastInt16ToString,
    CastInt2VectorToArray,
    CastInt32ToBool,
    CastInt32ToFloat32,
    CastInt32ToFloat64,
    CastInt32ToOid,
    CastInt32ToPgLegacyChar,
    CastInt32ToInt16,
    CastInt32ToInt64,
    CastInt32ToUint16,
    CastInt32ToUint32,
    CastInt32ToUint64,
    CastInt32ToString,
    CastOidToInt32,
    CastOidToInt64,
    CastOidToString,
    CastOidToRegClass,
    CastRegClassToOid,
    CastOidToRegProc,
    CastRegProcToOid,
    CastOidToRegType,
    CastRegTypeToOid,
    CastInt64ToInt16,
    CastInt64ToInt32,
    CastInt64ToUint16,
    CastInt64ToUint32,
    CastInt64ToUint64,
    CastInt16ToNumeric,
    CastInt32ToNumeric,
    CastInt64ToBool,
    CastInt64ToNumeric,
    CastInt64ToFloat32,
    CastInt64ToFloat64,
    CastInt64ToOid,
    CastInt64ToString,
    CastUint16ToUint32,
    CastUint16ToUint64,
    CastUint16ToInt16,
    CastUint16ToInt32,
    CastUint16ToInt64,
    CastUint16ToNumeric,
    CastUint16ToFloat32,
    CastUint16ToFloat64,
    CastUint16ToString,
    CastUint32ToUint16,
    CastUint32ToUint64,
    CastUint32ToInt16,
    CastUint32ToInt32,
    CastUint32ToInt64,
    CastUint32ToNumeric,
    CastUint32ToFloat32,
    CastUint32ToFloat64,
    CastUint32ToString,
    CastUint64ToUint16,
    CastUint64ToUint32,
    CastUint64ToInt16,
    CastUint64ToInt32,
    CastUint64ToInt64,
    CastUint64ToNumeric,
    CastUint64ToFloat32,
    CastUint64ToFloat64,
    CastUint64ToString,
    CastFloat32ToInt16,
    CastFloat32ToInt32,
    CastFloat32ToInt64,
    CastFloat32ToUint16,
    CastFloat32ToUint32,
    CastFloat32ToUint64,
    CastFloat32ToFloat64,
    CastFloat32ToString,
    CastFloat32ToNumeric,
    CastFloat64ToNumeric,
    CastFloat64ToInt16,
    CastFloat64ToInt32,
    CastFloat64ToInt64,
    CastFloat64ToUint16,
    CastFloat64ToUint32,
    CastFloat64ToUint64,
    CastFloat64ToFloat32,
    CastFloat64ToString,
    CastNumericToFloat32,
    CastNumericToFloat64,
    CastNumericToInt16,
    CastNumericToInt32,
    CastNumericToInt64,
    CastNumericToUint16,
    CastNumericToUint32,
    CastNumericToUint64,
    CastNumericToString,
    CastMzTimestampToString,
    CastMzTimestampToTimestamp,
    CastMzTimestampToTimestampTz,
    CastStringToMzTimestamp,
    CastUint64ToMzTimestamp,
    CastUint32ToMzTimestamp,
    CastInt64ToMzTimestamp,
    CastInt32ToMzTimestamp,
    CastNumericToMzTimestamp,
    CastTimestampToMzTimestamp,
    CastTimestampTzToMzTimestamp,
    CastDateToMzTimestamp,
    CastStringToBool,
    CastStringToPgLegacyChar,
    CastStringToPgLegacyName,
    CastStringToBytes,
    CastStringToInt16,
    CastStringToInt32,
    CastStringToInt64,
    CastStringToUint16,
    CastStringToUint32,
    CastStringToUint64,
    CastStringToInt2Vector,
    CastStringToOid,
    CastStringToFloat32,
    CastStringToFloat64,
    CastStringToDate,
    CastStringToArray,
    CastStringToList,
    CastStringToMap,
    CastStringToRange,
    CastStringToTime,
    CastStringToTimestamp,
    CastStringToTimestampTz,
    CastStringToInterval,
    CastStringToNumeric,
    CastStringToUuid,
    CastStringToChar,
    PadChar,
    CastStringToVarChar,
    CastCharToString,
    CastVarCharToString,
    CastDateToTimestamp,
    CastDateToTimestampTz,
    CastDateToString,
    CastTimeToInterval,
    CastTimeToString,
    CastIntervalToString,
    CastIntervalToTime,
    CastTimestampToDate,
    AdjustTimestampPrecision,
    CastTimestampToTimestampTz,
    CastTimestampToString,
    CastTimestampToTime,
    CastTimestampTzToDate,
    CastTimestampTzToTimestamp,
    AdjustTimestampTzPrecision,
    CastTimestampTzToString,
    CastTimestampTzToTime,
    CastPgLegacyCharToString,
    CastPgLegacyCharToChar,
    CastPgLegacyCharToVarChar,
    CastPgLegacyCharToInt32,
    CastBytesToString,
    CastStringToJsonb,
    CastJsonbToString,
    CastJsonbableToJsonb,
    CastJsonbToInt16,
    CastJsonbToInt32,
    CastJsonbToInt64,
    CastJsonbToFloat32,
    CastJsonbToFloat64,
    CastJsonbToNumeric,
    CastJsonbToBool,
    CastUuidToString,
    CastRecordToString,
    CastRecord1ToRecord2,
    CastArrayToArray,
    CastArrayToJsonb,
    CastArrayToString,
    CastListToString,
    CastListToJsonb,
    CastList1ToList2,
    CastArrayToListOneDim,
    CastMapToString,
    CastInt2VectorToString,
    CastRangeToString,
    CeilFloat32,
    CeilFloat64,
    CeilNumeric,
    FloorFloat32,
    FloorFloat64,
    FloorNumeric,
    Ascii,
    BitCountBytes,
    BitLengthBytes,
    BitLengthString,
    ByteLengthBytes,
    ByteLengthString,
    CharLength,
    Chr,
    IsLikeMatch,
    IsRegexpMatch,
    RegexpMatch,
    ExtractInterval,
    ExtractTime,
    ExtractTimestamp,
    ExtractTimestampTz,
    ExtractDate,
    DatePartInterval,
    DatePartTime,
    DatePartTimestamp,
    DatePartTimestampTz,
    DateTruncTimestamp,
    DateTruncTimestampTz,
    TimezoneTimestamp,
    TimezoneTimestampTz,
    TimezoneTime,
    ToTimestamp,
    ToCharTimestamp,
    ToCharTimestampTz,
    JustifyDays,
    JustifyHours,
    JustifyInterval,
    JsonbArrayLength,
    JsonbTypeof,
    JsonbStripNulls,
    JsonbPretty,
    RoundFloat32,
    RoundFloat64,
    RoundNumeric,
    TruncFloat32,
    TruncFloat64,
    TruncNumeric,
    TrimWhitespace,
    TrimLeadingWhitespace,
    TrimTrailingWhitespace,
    Initcap,
    RecordGet,
    ListLength,
    MapLength,
    MapBuildFromRecordList,
    Upper,
    Lower,
    Cos,
    Acos,
    Cosh,
    Acosh,
    Sin,
    Asin,
    Sinh,
    Asinh,
    Tan,
    Atan,
    Tanh,
    Atanh,
    Cot,
    Degrees,
    Radians,
    Log10,
    Log10Numeric,
    Ln,
    LnNumeric,
    Exp,
    ExpNumeric,
    Sleep,
    Panic,
    AdjustNumericScale,
    PgColumnSize,
    MzRowSize,
    MzTypeName,
    StepMzTimestamp,
    RangeLower,
    RangeUpper,
    RangeEmpty,
    RangeLowerInc,
    RangeUpperInc,
    RangeLowerInf,
    RangeUpperInf,
    MzAclItemGrantor,
    MzAclItemGrantee,
    MzAclItemPrivileges,
    MzFormatPrivileges,
    MzValidatePrivileges,
    MzValidateRolePrivilege,
    AclItemGrantor,
    AclItemGrantee,
    AclItemPrivileges,
    QuoteIdent,
    TryParseMonotonicIso8601Timestamp,
    RegexpSplitToArray,
    PgSizePretty,
    Crc32Bytes,
    Crc32String,
    KafkaMurmur2Bytes,
    KafkaMurmur2String,
    SeahashBytes,
    SeahashString,
    Reverse
);
525
526impl UnaryFunc {
527    /// If the unary_func represents "IS X", return X.
528    ///
529    /// A helper method for being able to print Not(IsX) as IS NOT X.
530    pub fn is(&self) -> Option<&'static str> {
531        match self {
532            UnaryFunc::IsNull(_) => Some("NULL"),
533            UnaryFunc::IsTrue(_) => Some("TRUE"),
534            UnaryFunc::IsFalse(_) => Some("FALSE"),
535            _ => None,
536        }
537    }
538}
539
// Tests for `UnaryFunc` metadata: `could_error` and `is_monotone`.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use mz_repr::{PropDatum, SqlScalarType};

    use crate::like_pattern;

    use super::*;

    /// Checks that each listed function reports `could_error() == false`.
    #[mz_ore::test]
    fn test_could_error() {
        for func in [
            UnaryFunc::IsNull(IsNull),
            UnaryFunc::CastVarCharToString(CastVarCharToString),
            UnaryFunc::Not(Not),
            UnaryFunc::IsLikeMatch(IsLikeMatch(like_pattern::compile("%hi%", false).unwrap())),
        ] {
            assert!(!func.could_error())
        }
    }

    /// Property test for a handful of functions that report `is_monotone()`.
    ///
    /// It checks that evaluating such a function on sorted arguments yields
    /// results that are ordered, either ascending or descending.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn test_is_monotone() {
        use proptest::prelude::*;

        /// Asserts that the function is either monotonically increasing or decreasing over
        /// the given sets of arguments.
        ///
        /// If evaluating any argument set returns an error, the whole sample is
        /// skipped rather than checked.
        fn assert_monotone<'a, const N: usize>(
            expr: &MirScalarExpr,
            arena: &'a RowArena,
            datums: &[[Datum<'a>; N]],
        ) {
            // TODO: assertions for nulls, errors
            let Ok(results) = datums
                .iter()
                .map(|args| expr.eval(args.as_slice(), arena))
                .collect::<Result<Vec<_>, _>>()
            else {
                return;
            };

            // Non-strict comparisons, matching the non-strict definition of
            // `is_monotone`: equal neighbouring results are allowed.
            let forward = results.iter().tuple_windows().all(|(a, b)| a <= b);
            let reverse = results.iter().tuple_windows().all(|(a, b)| a >= b);
            assert!(
                forward || reverse,
                "expected {expr} to be monotone, but passing {datums:?} returned {results:?}"
            );
        }

        /// Runs the monotonicity property test for `func` applied to column 0,
        /// drawing arguments from `arg`.
        ///
        /// Only functions whose `is_monotone()` returns `true` are exercised.
        /// For any other function this is a no-op.
        fn proptest_unary<'a>(
            func: UnaryFunc,
            arena: &'a RowArena,
            arg: impl Strategy<Value = PropDatum>,
        ) {
            let is_monotone = func.is_monotone();
            let expr = MirScalarExpr::CallUnary {
                func,
                expr: Box::new(MirScalarExpr::column(0)),
            };
            if is_monotone {
                // Each case draws three arguments, sorts them, and checks that
                // the corresponding outputs are ordered.
                proptest!(|(
                    mut arg in proptest::array::uniform3(arg),
                )| {
                    arg.sort();
                    let args: Vec<_> = arg.iter().map(|a| [Datum::from(a)]).collect();
                    assert_monotone(&expr, arena, &args);
                });
            }
        }

        // Argument strategy for Int32 inputs: a mix of arbitrary i32s, values
        // from `SqlScalarType::Int32.interesting_datums()`, and small values
        // in -10..10.
        let interesting_i32s: Vec<Datum<'static>> =
            SqlScalarType::Int32.interesting_datums().collect();
        let i32_datums = proptest::strategy::Union::new([
            any::<i32>().prop_map(PropDatum::Int32).boxed(),
            (0..interesting_i32s.len())
                .prop_map(move |i| {
                    let Datum::Int32(val) = interesting_i32s[i] else {
                        unreachable!("interesting int32 has non-i32s")
                    };
                    PropDatum::Int32(val)
                })
                .boxed(),
            (-10i32..10).prop_map(PropDatum::Int32).boxed(),
        ]);

        // One arena shared by every generated case.
        let arena = RowArena::new();

        // It would be interesting to test all funcs here, but we currently need to hardcode
        // the generators for the argument types, which makes this tedious. Choose an interesting
        // subset for now.
        proptest_unary(
            UnaryFunc::CastInt32ToNumeric(CastInt32ToNumeric(None)),
            &arena,
            &i32_datums,
        );
        proptest_unary(
            UnaryFunc::CastInt32ToUint16(CastInt32ToUint16),
            &arena,
            &i32_datums,
        );
        // NOTE(review): integer-to-string casts are presumably not monotone
        // (e.g. "10" sorts before "9"). If `is_monotone()` is false for this
        // function, `proptest_unary` skips it and this call checks nothing;
        // confirm that is intended.
        proptest_unary(
            UnaryFunc::CastInt32ToString(CastInt32ToString),
            &arena,
            &i32_datums,
        );
    }
}