// Source file: mz_expr/scalar/func/macros.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

/// Wraps a unary function value as `Some(UnaryFunc)`.
///
/// This is the shape an `inverse` implementation returns, so it can write
/// `to_unary!(SomeFunc)` instead of spelling out the conversion.
macro_rules! to_unary {
    ($func:expr) => {
        Some($crate::UnaryFunc::from($func))
    };
}

#[cfg(test)]
mod test {
    // Tests for the `#[sqlfunc]` attribute. The functions below differ only in
    // whether their argument and/or return value is wrapped in `Option` (and,
    // for the second group, whether the result is wrapped in `Result`). The
    // tests pin down how those signatures map onto the null-handling flags and
    // onto the nullability of the reported output type.

    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::LazyUnaryFunc;

    #[sqlfunc(sqlname = "INFALLIBLE")]
    fn infallible1(a: f32) -> f32 {
        a
    }

    #[sqlfunc]
    fn infallible2(a: Option<f32>) -> f32 {
        a.unwrap_or_default()
    }

    #[sqlfunc]
    fn infallible3(a: f32) -> Option<f32> {
        Some(a)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        // An explicit `sqlname` becomes the display name.
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");

        // Each tuple below is `(propagates_nulls, introduces_nulls)`.

        // Plain argument, plain result.
        assert_eq!(
            (Infallible1.propagates_nulls(), Infallible1.introduces_nulls()),
            (true, false),
        );
        // `Option` argument, plain result.
        assert_eq!(
            (Infallible2.propagates_nulls(), Infallible2.introduces_nulls()),
            (false, false),
        );
        // Plain argument, `Option` result.
        assert_eq!(
            (Infallible3.propagates_nulls(), Infallible3.introduces_nulls()),
            (true, true),
        );
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        let float32 = |nullable: bool| SqlScalarType::Float32.nullable(nullable);

        for nullable in [true, false] {
            // Plain in, plain out: the input's nullability carries through.
            assert_eq!(
                Infallible1.output_sql_type(float32(nullable)),
                float32(nullable)
            );
            // `Option` in, plain out: the result is never nullable.
            assert_eq!(
                Infallible2.output_sql_type(float32(nullable)),
                float32(false)
            );
            // Plain in, `Option` out: the result is always nullable.
            assert_eq!(
                Infallible3.output_sql_type(float32(nullable)),
                float32(true)
            );
        }
    }

    #[sqlfunc]
    fn fallible1(a: f32) -> Result<f32, EvalError> {
        Ok(a)
    }

    #[sqlfunc]
    fn fallible2(a: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default())
    }

    #[sqlfunc]
    fn fallible3(a: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Same expectations as the infallible group: wrapping the result in
        // `Result` leaves `(propagates_nulls, introduces_nulls)` unchanged.
        assert_eq!(
            (Fallible1.propagates_nulls(), Fallible1.introduces_nulls()),
            (true, false),
        );
        assert_eq!(
            (Fallible2.propagates_nulls(), Fallible2.introduces_nulls()),
            (false, false),
        );
        assert_eq!(
            (Fallible3.propagates_nulls(), Fallible3.introduces_nulls()),
            (true, true),
        );
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        let float32 = |nullable: bool| SqlScalarType::Float32.nullable(nullable);

        for nullable in [true, false] {
            // Plain in, plain out: the input's nullability carries through.
            assert_eq!(
                Fallible1.output_sql_type(float32(nullable)),
                float32(nullable)
            );
            // `Option` in, plain out: the result is never nullable.
            assert_eq!(
                Fallible2.output_sql_type(float32(nullable)),
                float32(false)
            );
            // Plain in, `Option` out: the result is always nullable.
            assert_eq!(
                Fallible3.output_sql_type(float32(nullable)),
                float32(true)
            );
        }
    }
}

/// Generates the `UnaryFunc` enum, its delegating `impl` block, a `Display`
/// impl, and a `From<Variant>` impl for each variant.
///
/// Unlike `derive_binary!` and `derive_variadic!`, each entry is a bare
/// identifier that names both the variant and the type it wraps. For example,
/// `derive_unary!(Foo, Bar)` yields `UnaryFunc::Foo(Foo)` and
/// `UnaryFunc::Bar(Bar)`. A trailing comma after the last entry is accepted.
///
/// The expansion refers to `Datum`, `RowArena`, `MirScalarExpr`, `EvalError`,
/// `SqlColumnType`, `ReprColumnType`, `LazyUnaryFunc` and `fmt` by bare name,
/// so these must be in scope where the macro is invoked. Requirements on each
/// wrapped type:
/// - it implements `LazyUnaryFunc`, which the generated methods call;
/// - it exposes an `eval` method taking `(datums, temp_storage, a)`;
/// - it exposes a `fmt` method, which the `Display` impl forwards to.
///
/// This is a temporary stand-in for what `enum_dispatch` would generate. It
/// exists so that variants still using the older, manual definitions can be
/// delegated to. Once every variant is handled here, the macro can be removed
/// and replaced with `enum_dispatch`.
macro_rules! derive_unary {
    ($($name:ident),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum UnaryFunc {
            $($name($name),)*
        }

        impl UnaryFunc {
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                a: &'a MirScalarExpr,
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, a),)*
                }
            }

            pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_sql_type(f, input_type),)*
                }
            }

            pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_type(f, input_type),)*
                }
            }

            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::propagates_nulls(f),)*
                }
            }

            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::introduces_nulls(f),)*
                }
            }

            pub fn preserves_uniqueness(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::preserves_uniqueness(f),)*
                }
            }

            pub fn inverse(&self) -> Option<UnaryFunc> {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::inverse(f),)*
                }
            }

            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::is_monotone(f),)*
                }
            }

            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::could_error(f),)*
                }
            }
        }

        impl fmt::Display for UnaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$name(func) => func.fmt(f),)*
                }
            }
        }

        $(
            impl From<$name> for crate::UnaryFunc {
                fn from(variant: $name) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}

/// Expands to the `VariadicFunc` enum plus the following:
/// - an inherent `impl` whose methods dispatch to the wrapped value;
/// - a `Display` impl;
/// - one `From<InnerType> for crate::VariadicFunc` impl per variant.
///
/// Input is a comma-separated list of `Variant(InnerType)` pairs, and a
/// trailing comma is allowed. The pair form is required even when both names
/// are the same, so write e.g. `And(And)`.
///
/// Dispatch targets:
/// - `eval` calls the inner value's `eval` method.
/// - `output_sql_type` calls `LazyVariadicFunc::output_type`.
/// - `output_type` converts through SQL types (see its own docs).
/// - `Display` calls the inner value's `fmt`.
/// - Every remaining predicate calls the `LazyVariadicFunc` method of the same
///   name.
macro_rules! derive_variadic {
    ($($variant:ident ( $inner:ident )),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub enum VariadicFunc {
            $($variant($inner),)*
        }

        impl VariadicFunc {
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                exprs: &'a [MirScalarExpr],
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$variant(inner) => inner.eval(datums, temp_storage, exprs),)*
                }
            }

            pub fn output_sql_type(&self, input_types: Vec<SqlColumnType>) -> SqlColumnType {
                match self {
                    $(
                        Self::$variant(inner) => {
                            LazyVariadicFunc::output_type(inner, &input_types)
                        }
                    )*
                }
            }

            /// Computes the representation type of this variadic function.
            ///
            /// This converts each input to its SQL column type, defers to
            /// [`Self::output_sql_type`], and converts the answer back into a
            /// representation type.
            pub fn output_type(&self, input_types: Vec<ReprColumnType>) -> ReprColumnType {
                let sql_inputs: Vec<SqlColumnType> =
                    input_types.iter().map(SqlColumnType::from_repr).collect();
                let sql_output = self.output_sql_type(sql_inputs);
                ReprColumnType::from(&sql_output)
            }

            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::propagates_nulls(inner),)*
                }
            }

            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::introduces_nulls(inner),)*
                }
            }

            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::could_error(inner),)*
                }
            }

            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::is_monotone(inner),)*
                }
            }

            pub fn is_associative(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::is_associative(inner),)*
                }
            }

            pub fn is_infix_op(&self) -> bool {
                match self {
                    $(Self::$variant(inner) => LazyVariadicFunc::is_infix_op(inner),)*
                }
            }
        }

        impl fmt::Display for VariadicFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$variant(func) => func.fmt(f),)*
                }
            }
        }

        $(
            impl From<$inner> for crate::VariadicFunc {
                fn from(variant: $inner) -> Self {
                    Self::$variant(variant)
                }
            }
        )*
    }
}

/// Generates the `BinaryFunc` enum, its inherent `impl` block of delegating
/// methods, a `Display` impl, and a `From<InnerType>` impl for each variant.
///
/// All variants must use explicit `Name(Type)` syntax. When the variant name
/// equals the inner type name, write e.g. `AddInt16(AddInt16)`. A trailing
/// comma after the last variant is accepted.
///
/// Delegation rules:
/// - `eval` forwards to the inner value's `eval` method.
/// - Every other inherent method (`output_sql_type`, `output_type`,
///   `propagates_nulls`, `introduces_nulls`, `is_infix_op`, `negate`,
///   `could_error`, `is_monotone`) forwards to the `LazyBinaryFunc` method of
///   the same name.
/// - `Display` forwards to the inner value's `fmt`.
///
/// `Datum`, `RowArena`, `MirScalarExpr`, `EvalError`, `SqlColumnType`,
/// `ReprColumnType`, `LazyBinaryFunc` and `fmt` are referenced by bare name,
/// so they must be in scope where the macro is invoked.
macro_rules! derive_binary {
    ($($name:ident ( $variant:ident )),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum BinaryFunc {
            $($name($variant),)*
        }

        impl BinaryFunc {
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                exprs: &[&'a MirScalarExpr],
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, exprs),)*
                }
            }

            pub fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::output_sql_type(f, input_types),)*
                }
            }

            pub fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::output_type(f, input_types),)*
                }
            }

            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::propagates_nulls(f),)*
                }
            }

            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::introduces_nulls(f),)*
                }
            }

            pub fn is_infix_op(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::is_infix_op(f),)*
                }
            }

            pub fn negate(&self) -> Option<BinaryFunc> {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::negate(f),)*
                }
            }

            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::could_error(f),)*
                }
            }

            pub fn is_monotone(&self) -> (bool, bool) {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::is_monotone(f),)*
                }
            }
        }

        impl fmt::Display for BinaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$name(func) => func.fmt(f),)*
                }
            }
        }

        $(
            impl From<$variant> for crate::BinaryFunc {
                fn from(variant: $variant) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}