// mz_expr/scalar/func/macros.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

/// Convenience macro for generating `inverse` values.
///
/// Expands to `Some(crate::UnaryFunc::from(f))`: the given function value is converted into a
/// `crate::UnaryFunc` through its `From` impl and wrapped in `Some`, which is the shape the
/// `#[inverse = ...]` attribute of `sqlfunc!` expects. A trailing comma is accepted.
macro_rules! to_unary {
    ($f:expr $(,)?) => {
        Some(crate::UnaryFunc::from($f))
    };
}

/// Defines a unary SQL function from a plain Rust function of one argument.
///
/// The function may be preceded by the following attributes, which must appear in this order:
///
/// * `#[sqlname = ...]`: the SQL name; defaults to the stringified Rust function name.
/// * `#[preserves_uniqueness = ...]`: defaults to `false`.
/// * `#[inverse = ...]`: defaults to `None` (usually written with `to_unary!`).
/// * `#[is_monotone = ...]`: defaults to `false`.
///
/// Every attribute is optional, but if any of the last three is given, `#[sqlname]` must be
/// given as well. The arms below first fill in omitted attributes, then add a `'a` lifetime
/// parameter when the function declares no generic list (so parameter and return types may
/// mention `'a` without declaring it), then rewrite a `mut` parameter as a `let mut` rebinding.
/// Only then is the normalized function handed to the `mz_expr_derive::sqlfunc` attribute.
///
/// The expansion also contains a `#[cfg(test)]` module named after the function. That module
/// holds an `insta` snapshot test recording the properties derived for the generated function
/// type. NOTE: the snapshot test always uses `Float32` input column types, whatever the
/// function's parameter type is.
macro_rules! sqlfunc {
    // No attributes: use the stringified Rust name as the SQL name.
    (
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = stringify!($fn_name)]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` only: default `preserves_uniqueness` and `inverse`.
    (
        #[sqlname = $name:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = false]
            #[inverse = None]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `is_monotone`: default `preserves_uniqueness`.
    (
        #[sqlname = $name:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = false]
            #[is_monotone = $is_monotone]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `inverse`: default `preserves_uniqueness`.
    (
        #[sqlname = $name:expr]
        #[inverse = $inverse:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = false]
            #[inverse = $inverse]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `inverse` + `is_monotone`: default `preserves_uniqueness`.
    (
        #[sqlname = $name:expr]
        #[inverse = $inverse:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = false]
            #[inverse = $inverse]
            #[is_monotone = $is_monotone]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `preserves_uniqueness`: default `inverse`.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = $preserves_uniqueness]
            #[inverse = None]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `preserves_uniqueness` + `is_monotone`: default `inverse`.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = $preserves_uniqueness]
            #[inverse = None]
            #[is_monotone = $is_monotone]
            fn $fn_name $($tail)*
        );
    };

    // `sqlname` + `preserves_uniqueness` + `inverse`: default `is_monotone`.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        #[inverse = $inverse:expr]
        fn $fn_name:ident $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = $preserves_uniqueness]
            #[inverse = $inverse]
            #[is_monotone = false]
            fn $fn_name $($tail)*
        );
    };

    // All attributes present, but no generic parameter list: introduce `'a` so the signature
    // may refer to it without declaring it.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        #[inverse = $inverse:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident ($($params:tt)*) $($tail:tt)*
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = $preserves_uniqueness]
            #[inverse = $inverse]
            #[is_monotone = $is_monotone]
            fn $fn_name<'a>($($params)*) $($tail)*
        );
    };

    // Normalize a `mut` parameter: take it immutably and rebind it as `let mut` in the body.
    // This arm is tried before the final one.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        #[inverse = $inverse:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident<$lt:lifetime>(mut $param_name:ident: $input_ty:ty $(,)?) -> $output_ty:ty
            $body:block
    ) => {
        sqlfunc!(
            #[sqlname = $name]
            #[preserves_uniqueness = $preserves_uniqueness]
            #[inverse = $inverse]
            #[is_monotone = $is_monotone]
            fn $fn_name<$lt>($param_name: $input_ty) -> $output_ty {
                let mut $param_name = $param_name;
                $body
            }
        );
    };

    // Fully normalized form: emit the function under the `mz_expr_derive::sqlfunc` attribute,
    // plus a snapshot test of the properties derived for it.
    (
        #[sqlname = $name:expr]
        #[preserves_uniqueness = $preserves_uniqueness:expr]
        #[inverse = $inverse:expr]
        #[is_monotone = $is_monotone:expr]
        fn $fn_name:ident<$lt:lifetime>($param_name:ident: $input_ty:ty $(,)?) -> $output_ty:ty
            $body:block
    ) => {
        // `paste` builds the CamelCase name (`[<$fn_name:camel>]`) the test refers to.
        paste::paste! {
            #[mz_expr_derive::sqlfunc(
                sqlname = $name,
                preserves_uniqueness = $preserves_uniqueness,
                inverse = $inverse,
                is_monotone = $is_monotone,
            )]
            // The lifetime may be the unconditionally-added `'a`, which the signature need not use.
            #[allow(clippy::extra_unused_lifetimes)]
            pub fn $fn_name<$lt>($param_name: $input_ty) -> $output_ty {
                $body
            }

            // Only compiled for tests, so non-test builds don't carry an empty module.
            #[cfg(test)]
            mod $fn_name {
                #[mz_ore::test]
                // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
                #[cfg_attr(miri, ignore)]
                fn test_sqlfunc_macro() {
                    use crate::func::EagerUnaryFunc;
                    let f = super::[<$fn_name:camel>];
                    let input_type = mz_repr::ColumnType {
                        scalar_type: mz_repr::ScalarType::Float32,
                        nullable: true,
                    };
                    let output_type_nullable = f.output_type(input_type);

                    let input_type = mz_repr::ColumnType {
                        scalar_type: mz_repr::ScalarType::Float32,
                        nullable: false,
                    };
                    let output_type_nonnullable = f.output_type(input_type);

                    let preserves_uniqueness = f.preserves_uniqueness();
                    let inverse = f.inverse();
                    let is_monotone = f.is_monotone();
                    let propagates_nulls = f.propagates_nulls();
                    let introduces_nulls = f.introduces_nulls();
                    let could_error = f.could_error();

                    #[derive(Debug)]
                    #[allow(unused)]
                    struct Info {
                        output_type_nullable: mz_repr::ColumnType,
                        output_type_nonnullable: mz_repr::ColumnType,
                        preserves_uniqueness: bool,
                        inverse: Option<crate::UnaryFunc>,
                        is_monotone: bool,
                        propagates_nulls: bool,
                        introduces_nulls: bool,
                        could_error: bool,
                    }

                    let info = Info {
                        output_type_nullable,
                        output_type_nonnullable,
                        preserves_uniqueness,
                        inverse,
                        is_monotone,
                        propagates_nulls,
                        introduces_nulls,
                        could_error,
                    };

                    insta::assert_debug_snapshot!(info);
                }
            }
        }
    };
}

#[cfg(test)]
mod test {
    use mz_repr::ScalarType;

    use crate::EvalError;
    use crate::scalar::func::LazyUnaryFunc;

    // Infallible fixtures covering three shapes: plain input and output, `Option` input, and
    // `Option` output.

    sqlfunc!(
        #[sqlname = "INFALLIBLE"]
        fn infallible1(a: f32) -> f32 {
            a
        }
    );

    sqlfunc!(
        fn infallible2(a: Option<f32>) -> f32 {
            a.unwrap_or_default()
        }
    );

    sqlfunc!(
        fn infallible3(a: f32) -> Option<f32> {
            Some(a)
        }
    );

    // Fallible fixtures: the same three shapes with the result wrapped in
    // `Result<_, EvalError>`.

    sqlfunc!(
        fn fallible1(a: f32) -> Result<f32, EvalError> {
            Ok(a)
        }
    );

    sqlfunc!(
        fn fallible2(a: Option<f32>) -> Result<f32, EvalError> {
            Ok(a.unwrap_or_default())
        }
    );

    sqlfunc!(
        fn fallible3(a: f32) -> Result<Option<f32>, EvalError> {
            Ok(Some(a))
        }
    );

    /// Checks the null-handling flags derived for `func`.
    fn check_nulls(func: &impl LazyUnaryFunc, propagates: bool, introduces: bool) {
        assert_eq!(func.propagates_nulls(), propagates, "propagates_nulls");
        assert_eq!(func.introduces_nulls(), introduces, "introduces_nulls");
    }

    /// Checks that `func` maps a `Float32` column with nullability `input_nullable` to a
    /// `Float32` column with nullability `output_nullable`.
    fn check_output(func: &impl LazyUnaryFunc, input_nullable: bool, output_nullable: bool) {
        let input = ScalarType::Float32.nullable(input_nullable);
        let expected = ScalarType::Float32.nullable(output_nullable);
        assert_eq!(func.output_type(input), expected);
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        check_nulls(&Infallible1, true, false);
        check_nulls(&Infallible2, false, false);
        check_nulls(&Infallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Plain input and output: input nullability is carried through.
        check_output(&Infallible1, true, true);
        check_output(&Infallible1, false, false);
        // `Option` input: expected output is non-nullable either way.
        check_output(&Infallible2, true, false);
        check_output(&Infallible2, false, false);
        // `Option` output: expected output is nullable either way.
        check_output(&Infallible3, true, true);
        check_output(&Infallible3, false, true);
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        check_nulls(&Fallible1, true, false);
        check_nulls(&Fallible2, false, false);
        check_nulls(&Fallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Plain input and output: input nullability is carried through.
        check_output(&Fallible1, true, true);
        check_output(&Fallible1, false, false);
        // `Option` input: expected output is non-nullable either way.
        check_output(&Fallible2, true, false);
        check_output(&Fallible2, false, false);
        // `Option` output: expected output is nullable either way.
        check_output(&Fallible3, true, true);
        check_output(&Fallible3, false, true);
    }
}

/// Temporary macro that generates the equivalent of what `enum_dispatch` will do in the future.
/// We need this manual macro implementation to delegate to the previous manual implementation
/// for variants that use the old definitions.
///
/// Given a comma-separated list of function types (a trailing comma is accepted), this expands
/// to:
///
/// * a `UnaryFunc` enum with one tuple variant per type, named after and wrapping that type;
/// * inherent methods on `UnaryFunc` that dispatch to the wrapped value (`eval` via method-call
///   syntax, the remaining methods via `LazyUnaryFunc`);
/// * a `fmt::Display` impl that forwards to the wrapped value's `fmt::Display` impl;
/// * a `From<T> for crate::UnaryFunc` impl for every listed type `T`.
///
/// The expansion names `Datum`, `RowArena`, `MirScalarExpr`, `EvalError`, `ColumnType`,
/// `LazyUnaryFunc` and `fmt` unqualified, so they must be in scope where the macro is invoked.
///
/// Once everything is handled by this macro we can remove it and replace it with
/// `enum_dispatch`.
macro_rules! derive_unary {
    ($($name:ident),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum UnaryFunc {
            $($name($name),)*
        }

        impl UnaryFunc {
            /// Evaluates the wrapped function by delegating to the variant's `eval`.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                a: &'a MirScalarExpr,
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, a),)*
                }
            }

            /// Output column type for `input_type`, per the variant's `LazyUnaryFunc` impl.
            pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_type(f, input_type),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::propagates_nulls`.
            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::propagates_nulls(f),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::introduces_nulls`.
            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::introduces_nulls(f),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::preserves_uniqueness`.
            pub fn preserves_uniqueness(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::preserves_uniqueness(f),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::inverse`.
            pub fn inverse(&self) -> Option<UnaryFunc> {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::inverse(f),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::is_monotone`.
            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::is_monotone(f),)*
                }
            }

            /// Delegates to the variant's `LazyUnaryFunc::could_error`.
            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::could_error(f),)*
                }
            }
        }

        impl fmt::Display for UnaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                // Call `Display` explicitly. A bare `func.fmt(f)` resolves to whichever `fmt`
                // trait method is in scope at the invocation site: that is ambiguous if both
                // `Display` and `Debug` are imported, and wrong if only `Debug` is.
                match self {
                    $(Self::$name(func) => fmt::Display::fmt(func, f),)*
                }
            }
        }

        $(
            impl From<$name> for crate::UnaryFunc {
                fn from(variant: $name) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}