mz_expr/scalar/func/
binary.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Utilities for binary functions: the `LazyBinaryFunc` and `EagerBinaryFunc` traits.

12use mz_repr::{ColumnType, Datum, DatumType, RowArena};
13
14use crate::{EvalError, MirScalarExpr};
15
16/// A description of an SQL binary function that has the ability to lazy evaluate its arguments
17// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
18#[allow(unused)]
19pub(crate) trait LazyBinaryFunc {
20    fn eval<'a>(
21        &'a self,
22        datums: &[Datum<'a>],
23        temp_storage: &'a RowArena,
24        a: &'a MirScalarExpr,
25        b: &'a MirScalarExpr,
26    ) -> Result<Datum<'a>, EvalError>;
27
28    /// The output ColumnType of this function.
29    fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType;
30
31    /// Whether this function will produce NULL on NULL input.
32    fn propagates_nulls(&self) -> bool;
33
34    /// Whether this function will produce NULL on non-NULL input.
35    fn introduces_nulls(&self) -> bool;
36
37    /// Whether this function might error on non-error input.
38    fn could_error(&self) -> bool {
39        // NB: override this for functions that never error.
40        true
41    }
42
43    /// Returns the negation of the function, if one exists.
44    fn negate(&self) -> Option<crate::BinaryFunc>;
45
46    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
47    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
48    /// determine the range of possible outputs just by mapping the endpoints.
49    ///
50    /// This describes the *pointwise* behaviour of the function:
51    /// ie. the behaviour of any specific argument as the others are held constant. (For example, `a - b` is
52    /// monotone in the first argument because for any particular value of `b`, increasing `a` will
53    /// always cause the result to increase... and in the second argument because for any specific `a`,
54    /// increasing `b` will always cause the result to _decrease_.)
55    ///
56    /// This property describes the behaviour of the function over ranges where the function is defined:
57    /// ie. the arguments and the result are non-error datums.
58    fn is_monotone(&self) -> (bool, bool);
59
60    /// Yep, I guess this returns true for infix operators.
61    fn is_infix_op(&self) -> bool;
62}
63
64#[allow(unused)]
65pub(crate) trait EagerBinaryFunc<'a> {
66    type Input1: DatumType<'a, EvalError>;
67    type Input2: DatumType<'a, EvalError>;
68    type Output: DatumType<'a, EvalError>;
69
70    fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a RowArena) -> Self::Output;
71
72    /// The output ColumnType of this function
73    fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType;
74
75    /// Whether this function will produce NULL on NULL input
76    fn propagates_nulls(&self) -> bool {
77        // If the inputs are not nullable then nulls are propagated
78        !Self::Input1::nullable() && !Self::Input2::nullable()
79    }
80
81    /// Whether this function will produce NULL on non-NULL input
82    fn introduces_nulls(&self) -> bool {
83        // If the output is nullable then nulls can be introduced
84        Self::Output::nullable()
85    }
86
87    /// Whether this function could produce an error
88    fn could_error(&self) -> bool {
89        Self::Output::fallible()
90    }
91
92    /// Returns the negation of the given binary function, if it exists.
93    fn negate(&self) -> Option<crate::BinaryFunc> {
94        None
95    }
96
97    fn is_monotone(&self) -> (bool, bool) {
98        (false, false)
99    }
100
101    fn is_infix_op(&self) -> bool {
102        false
103    }
104}
105
106impl<T: for<'a> EagerBinaryFunc<'a>> LazyBinaryFunc for T {
107    fn eval<'a>(
108        &'a self,
109        datums: &[Datum<'a>],
110        temp_storage: &'a RowArena,
111        a: &'a MirScalarExpr,
112        b: &'a MirScalarExpr,
113    ) -> Result<Datum<'a>, EvalError> {
114        let a = match T::Input1::try_from_result(a.eval(datums, temp_storage)) {
115            // If we can convert to the input type then we call the function
116            Ok(input) => input,
117            // If we can't and we got a non-null datum something went wrong in the planner
118            Err(Ok(datum)) if !datum.is_null() => {
119                return Err(EvalError::Internal("invalid input type".into()));
120            }
121            // Otherwise we just propagate NULLs and errors
122            Err(res) => return res,
123        };
124        let b = match T::Input2::try_from_result(b.eval(datums, temp_storage)) {
125            // If we can convert to the input type then we call the function
126            Ok(input) => input,
127            // If we can't and we got a non-null datum something went wrong in the planner
128            Err(Ok(datum)) if !datum.is_null() => {
129                return Err(EvalError::Internal("invalid input type".into()));
130            }
131            // Otherwise we just propagate NULLs and errors
132            Err(res) => return res,
133        };
134        self.call(a, b, temp_storage).into_result(temp_storage)
135    }
136
137    fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType {
138        self.output_type(input_type_a, input_type_b)
139    }
140
141    fn propagates_nulls(&self) -> bool {
142        self.propagates_nulls()
143    }
144
145    fn introduces_nulls(&self) -> bool {
146        self.introduces_nulls()
147    }
148
149    fn could_error(&self) -> bool {
150        self.could_error()
151    }
152
153    fn negate(&self) -> Option<crate::BinaryFunc> {
154        self.negate()
155    }
156
157    fn is_monotone(&self) -> (bool, bool) {
158        self.is_monotone()
159    }
160
161    fn is_infix_op(&self) -> bool {
162        self.is_infix_op()
163    }
164}
165
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::ColumnType;
    use mz_repr::ScalarType;

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::{BinaryFunc, EvalError, func};

    /// Every combination of (first input nullable, second input nullable).
    const NULLABILITY_CASES: [(bool, bool); 4] =
        [(true, true), (true, false), (false, true), (false, false)];

    /// Returns `(propagates_nulls, introduces_nulls)` for `f`.
    fn null_flags<F: LazyBinaryFunc>(f: &F) -> (bool, bool) {
        (f.propagates_nulls(), f.introduces_nulls())
    }

    /// Returns the output type of `f` applied to two `Float32` columns with the given
    /// nullability.
    fn float32_output<F: LazyBinaryFunc>(
        f: &F,
        a_nullable: bool,
        b_nullable: bool,
    ) -> ColumnType {
        f.output_type(
            ScalarType::Float32.nullable(a_nullable),
            ScalarType::Float32.nullable(b_nullable),
        )
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        // Plain inputs reject NULL (so it propagates); a plain output never introduces NULL.
        assert_eq!(null_flags(&Infallible1), (true, false));
        // `Option` inputs accept NULL, so the function itself sees and absorbs it.
        assert_eq!(null_flags(&Infallible2), (false, false));
        // An `Option` output may be NULL even for non-NULL inputs.
        assert_eq!(null_flags(&Infallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        for (a, b) in NULLABILITY_CASES {
            // Nullable iff either input is nullable.
            assert_eq!(
                float32_output(&Infallible1, a, b),
                ScalarType::Float32.nullable(a || b),
                "Infallible1 with inputs nullable ({a}, {b})"
            );
            // Never nullable: NULL inputs are absorbed.
            assert_eq!(
                float32_output(&Infallible2, a, b),
                ScalarType::Float32.nullable(false),
                "Infallible2 with inputs nullable ({a}, {b})"
            );
            // Always nullable: the output is an `Option`.
            assert_eq!(
                float32_output(&Infallible3, a, b),
                ScalarType::Float32.nullable(true),
                "Infallible3 with inputs nullable ({a}, {b})"
            );
        }
    }

    #[sqlfunc]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Wrapping the output in `Result` must not change the NULL-handling rules.
        assert_eq!(null_flags(&Fallible1), (true, false));
        assert_eq!(null_flags(&Fallible2), (false, false));
        assert_eq!(null_flags(&Fallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        for (a, b) in NULLABILITY_CASES {
            // Nullable iff either input is nullable.
            assert_eq!(
                float32_output(&Fallible1, a, b),
                ScalarType::Float32.nullable(a || b),
                "Fallible1 with inputs nullable ({a}, {b})"
            );
            // Never nullable: NULL inputs are absorbed.
            assert_eq!(
                float32_output(&Fallible2, a, b),
                ScalarType::Float32.nullable(false),
                "Fallible2 with inputs nullable ({a}, {b})"
            );
            // Always nullable: the successful output is an `Option`.
            assert_eq!(
                float32_output(&Fallible3, a, b),
                ScalarType::Float32.nullable(true),
                "Fallible3 with inputs nullable ({a}, {b})"
            );
        }
    }

    #[mz_ore::test]
    fn test_equivalence() {
        use BinaryFunc as BF;

        // Asserts that the trait-based function `new` agrees with the legacy `BinaryFunc`
        // variant `old` on all metadata, using `column_ty` as the type of both inputs.
        #[track_caller]
        fn check<T: LazyBinaryFunc + std::fmt::Display>(
            new: T,
            old: BinaryFunc,
            column_ty: ColumnType,
        ) {
            assert_eq!(
                new.propagates_nulls(),
                old.propagates_nulls(),
                "propagates_nulls mismatch"
            );
            assert_eq!(
                new.introduces_nulls(),
                old.introduces_nulls(),
                "introduces_nulls mismatch"
            );
            assert_eq!(new.could_error(), old.could_error(), "could_error mismatch");
            assert_eq!(new.is_monotone(), old.is_monotone(), "is_monotone mismatch");
            assert_eq!(new.is_infix_op(), old.is_infix_op(), "is_infix_op mismatch");

            let new_output = new.output_type(column_ty.clone(), column_ty.clone());
            let old_output = old.output_type(column_ty.clone(), column_ty);
            assert_eq!(new_output, old_output, "output_type mismatch");

            assert_eq!(format!("{new}"), format!("{old}"), "format mismatch");
        }

        let i32_ty = ColumnType {
            nullable: true,
            scalar_type: ScalarType::Int32,
        };
        let ts_tz_ty = ColumnType {
            nullable: true,
            scalar_type: ScalarType::TimestampTz { precision: None },
        };

        check(func::AddInt16, BF::AddInt16, i32_ty.clone());
        check(func::AddInt32, BF::AddInt32, i32_ty.clone());
        check(func::AddInt64, BF::AddInt64, i32_ty.clone());
        check(func::AddUint16, BF::AddUInt16, i32_ty.clone());
        check(func::AddUint32, BF::AddUInt32, i32_ty.clone());
        check(func::AddUint64, BF::AddUInt64, i32_ty.clone());
        check(func::AddFloat32, BF::AddFloat32, i32_ty.clone());
        check(func::AddFloat64, BF::AddFloat64, i32_ty.clone());
        check(func::AddDateTime, BF::AddDateTime, i32_ty.clone());
        check(func::AddDateInterval, BF::AddDateInterval, i32_ty.clone());
        check(func::AddTimeInterval, BF::AddTimeInterval, ts_tz_ty);
        check(func::RoundNumericBinary, BF::RoundNumeric, i32_ty);
    }
}