Skip to main content

mz_expr/scalar/func/
macros.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
/// Builds the `Option<UnaryFunc>` value returned by `inverse` implementations.
///
/// The argument may be any expression whose type has a `From` conversion into
/// `crate::UnaryFunc`; the converted function is returned wrapped in `Some`.
macro_rules! to_unary {
    ($inner:expr) => {
        Some(<crate::UnaryFunc as From<_>>::from($inner))
    };
}
16
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::LazyUnaryFunc;

    /// Asserts that `$func`, given a `Float32` input column with nullability `$input`, reports
    /// a `Float32` output column with nullability `$output`.
    macro_rules! assert_float32_output {
        ($func:expr, $input:expr => $output:expr) => {
            assert_eq!(
                $func.output_sql_type(SqlScalarType::Float32.nullable($input)),
                SqlScalarType::Float32.nullable($output),
            );
        };
    }

    // Infallible shapes under test: plain `T -> T`, an `Option<T>` argument, and an
    // `Option<T>` return value.

    #[sqlfunc(sqlname = "INFALLIBLE")]
    fn infallible1(a: f32) -> f32 {
        a
    }

    #[sqlfunc]
    fn infallible2(a: Option<f32>) -> f32 {
        a.unwrap_or_default()
    }

    #[sqlfunc]
    fn infallible3(a: f32) -> Option<f32> {
        Some(a)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");

        // Each pair below is `(propagates_nulls, introduces_nulls)`.
        assert_eq!(
            (Infallible1.propagates_nulls(), Infallible1.introduces_nulls()),
            (true, false),
        );
        assert_eq!(
            (Infallible2.propagates_nulls(), Infallible2.introduces_nulls()),
            (false, false),
        );
        assert_eq!(
            (Infallible3.propagates_nulls(), Infallible3.introduces_nulls()),
            (true, true),
        );
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Output nullability mirrors input nullability.
        assert_float32_output!(Infallible1, true => true);
        assert_float32_output!(Infallible1, false => false);

        // Output is non-nullable whatever the input nullability.
        assert_float32_output!(Infallible2, true => false);
        assert_float32_output!(Infallible2, false => false);

        // Output is nullable whatever the input nullability.
        assert_float32_output!(Infallible3, true => true);
        assert_float32_output!(Infallible3, false => true);
    }

    // Fallible counterparts of the shapes above: the same signatures wrapped in `Result`.

    #[sqlfunc]
    fn fallible1(a: f32) -> Result<f32, EvalError> {
        Ok(a)
    }

    #[sqlfunc]
    fn fallible2(a: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default())
    }

    #[sqlfunc]
    fn fallible3(a: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Each pair below is `(propagates_nulls, introduces_nulls)`.
        assert_eq!(
            (Fallible1.propagates_nulls(), Fallible1.introduces_nulls()),
            (true, false),
        );
        assert_eq!(
            (Fallible2.propagates_nulls(), Fallible2.introduces_nulls()),
            (false, false),
        );
        assert_eq!(
            (Fallible3.propagates_nulls(), Fallible3.introduces_nulls()),
            (true, true),
        );
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Output nullability mirrors input nullability.
        assert_float32_output!(Fallible1, true => true);
        assert_float32_output!(Fallible1, false => false);

        // Output is non-nullable whatever the input nullability.
        assert_float32_output!(Fallible2, true => false);
        assert_float32_output!(Fallible2, false => false);

        // Output is nullable whatever the input nullability.
        assert_float32_output!(Fallible3, true => true);
        assert_float32_output!(Fallible3, false => true);
    }
}
140
/// Generates the `UnaryFunc` enum, its `impl` block, a `Display` impl, and a
/// `From<InnerType>` impl for each variant.
///
/// Every variant wraps a type of the same name: `derive_unary!(Foo, Bar)` produces the
/// variants `Foo(Foo)` and `Bar(Bar)`. A trailing comma after the last name is accepted, as it
/// is by `derive_variadic!` and `derive_binary!`. Names must be unique; a duplicate would
/// produce a duplicate variant and conflicting `From` impls.
///
/// This is a temporary, hand-written equivalent of what `enum_dispatch` will generate in the
/// future. It is needed so that variants still using the old, manual definitions can be
/// delegated to alongside the new ones. Once every variant is handled by this macro, it can be
/// removed and replaced with `enum_dispatch`.
macro_rules! derive_unary {
    ($($name:ident),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum UnaryFunc {
            $($name($name),)*
        }

        impl UnaryFunc {
            /// Delegates to the wrapped function's `eval` method.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                a: &'a MirScalarExpr,
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, a),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::output_sql_type` on the wrapped function.
            pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_sql_type(f, input_type),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::output_type` on the wrapped function.
            pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_type(f, input_type),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::propagates_nulls` on the wrapped function.
            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::propagates_nulls(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::introduces_nulls` on the wrapped function.
            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::introduces_nulls(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::preserves_uniqueness` on the wrapped function.
            pub fn preserves_uniqueness(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::preserves_uniqueness(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::inverse` on the wrapped function.
            pub fn inverse(&self) -> Option<UnaryFunc> {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::inverse(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::is_monotone` on the wrapped function.
            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::is_monotone(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::could_error` on the wrapped function.
            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::could_error(f),)*
                }
            }

            /// Delegates to `LazyUnaryFunc::is_eliminable_cast` on the wrapped function.
            pub fn is_eliminable_cast(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::is_eliminable_cast(f),)*
                }
            }
        }

        impl fmt::Display for UnaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$name(func) => func.fmt(f),)*
                }
            }
        }

        $(
            impl From<$name> for crate::UnaryFunc {
                fn from(variant: $name) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}
233
/// Expands to the `VariadicFunc` enum together with its inherent dispatch methods, a
/// `Display` impl, and one `From<InnerType>` impl per variant.
///
/// Variants are always written as `Name(Type)`. A variant whose name matches its inner type is
/// still spelled out in full, e.g. `And(And)`. A trailing comma after the last variant is
/// permitted.
macro_rules! derive_variadic {
    ($($variant_name:ident ( $inner_ty:ident )),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub enum VariadicFunc {
            $($variant_name($inner_ty),)*
        }

        impl VariadicFunc {
            /// Forwards evaluation to the wrapped implementation's `eval` method.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                exprs: &'a [MirScalarExpr],
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$variant_name(inner) => {
                        inner.eval(datums, temp_storage, exprs)
                    })*
                }
            }

            /// Computes the SQL output type from the SQL types of the arguments via
            /// `LazyVariadicFunc::output_type`.
            pub fn output_sql_type(&self, input_types: Vec<SqlColumnType>) -> SqlColumnType {
                match self {
                    $(Self::$variant_name(inner) => {
                        LazyVariadicFunc::output_type(inner, &input_types)
                    })*
                }
            }

            /// Computes the representation output type of this variadic function.
            ///
            /// The representation input types are converted with `SqlColumnType::from_repr`,
            /// the SQL output type is computed by [`Self::output_sql_type`], and that result is
            /// converted back with `ReprColumnType::from`.
            pub fn output_type(&self, input_types: Vec<ReprColumnType>) -> ReprColumnType {
                let sql_input_types: Vec<SqlColumnType> =
                    input_types.iter().map(SqlColumnType::from_repr).collect();
                let sql_output_type = self.output_sql_type(sql_input_types);
                ReprColumnType::from(&sql_output_type)
            }

            /// Forwards to `LazyVariadicFunc::propagates_nulls`.
            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::propagates_nulls(inner),)*
                }
            }

            /// Forwards to `LazyVariadicFunc::introduces_nulls`.
            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::introduces_nulls(inner),)*
                }
            }

            /// Forwards to `LazyVariadicFunc::could_error`.
            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::could_error(inner),)*
                }
            }

            /// Forwards to `LazyVariadicFunc::is_monotone`.
            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::is_monotone(inner),)*
                }
            }

            /// Forwards to `LazyVariadicFunc::is_associative`.
            pub fn is_associative(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::is_associative(inner),)*
                }
            }

            /// Forwards to `LazyVariadicFunc::is_infix_op`.
            pub fn is_infix_op(&self) -> bool {
                match self {
                    $(Self::$variant_name(inner) => LazyVariadicFunc::is_infix_op(inner),)*
                }
            }
        }

        impl fmt::Display for VariadicFunc {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$variant_name(inner) => inner.fmt(formatter),)*
                }
            }
        }

        $(
            impl From<$inner_ty> for crate::VariadicFunc {
                fn from(inner: $inner_ty) -> Self {
                    Self::$variant_name(inner)
                }
            }
        )*
    }
}
329
/// Generates the `BinaryFunc` enum, its `impl` block, a `Display` impl, and a
/// `From<InnerType>` impl for each variant.
///
/// Every method in the generated `impl` block delegates to the wrapped type:
/// - `eval` calls the inner type's `eval` method;
/// - `output_sql_type`, `output_type`, `propagates_nulls`, `introduces_nulls`, `is_infix_op`,
///   `negate`, `could_error` and `is_monotone` call the `LazyBinaryFunc` method of the same
///   name.
///
/// All variants must use explicit `Name(Type)` syntax. When the variant name equals
/// the inner type name, write e.g. `AddInt16(AddInt16)`. A trailing comma is accepted.
/// Each inner type may appear in only one variant; otherwise the generated `From` impls
/// would conflict.
macro_rules! derive_binary {
    ($($name:ident ( $variant:ident )),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum BinaryFunc {
            $($name($variant),)*
        }

        impl BinaryFunc {
            /// Delegates to the wrapped function's `eval` method.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                exprs: &[&'a MirScalarExpr],
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, exprs),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::output_sql_type` on the wrapped function.
            pub fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::output_sql_type(f, input_types),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::output_type` on the wrapped function.
            pub fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::output_type(f, input_types),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::propagates_nulls` on the wrapped function.
            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::propagates_nulls(f),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::introduces_nulls` on the wrapped function.
            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::introduces_nulls(f),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::is_infix_op` on the wrapped function.
            pub fn is_infix_op(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::is_infix_op(f),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::negate` on the wrapped function.
            pub fn negate(&self) -> Option<BinaryFunc> {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::negate(f),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::could_error` on the wrapped function.
            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::could_error(f),)*
                }
            }

            /// Delegates to `LazyBinaryFunc::is_monotone` on the wrapped function; the pair
            /// holds one flag per argument.
            pub fn is_monotone(&self) -> (bool, bool) {
                match self {
                    $(Self::$name(f) => LazyBinaryFunc::is_monotone(f),)*
                }
            }
        }

        impl fmt::Display for BinaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$name(func) => func.fmt(f),)*
                }
            }
        }

        $(
            impl From<$variant> for crate::BinaryFunc {
                fn from(variant: $variant) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}