mz_expr_derive_impl/
sqlfunc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime};
15
16/// Modifiers passed as key-value pairs to the `#[sqlfunc]` macro.
17#[derive(Debug, Default, darling::FromMeta)]
18pub(crate) struct Modifiers {
19    /// An optional expression that evaluates to a boolean indicating whether the function is
20    /// monotone with respect to its arguments. Defined for unary and binary functions.
21    is_monotone: Option<Expr>,
22    /// The SQL name for the function. Applies to all functions.
23    sqlname: Option<String>,
24    /// Whether the function preserves uniqueness. Applies to unary functions.
25    preserves_uniqueness: Option<Expr>,
26    /// The inverse of the function, if it exists. Applies to unary functions.
27    inverse: Option<Expr>,
28    /// The negated function, if it exists. Applies to binary functions.
29    negate: Option<Expr>,
30    /// Whether the function is an infix operator. Applies to binary functions, and needs to
31    /// be specified.
32    is_infix_op: Option<Expr>,
33    /// The output type of the function, if it cannot be inferred. Applies to all functions.
34    output_type: Option<syn::Path>,
35    /// Optional expression evaluating to a boolean indicating whether the function could error.
36    /// Applies to all functions.
37    could_error: Option<Expr>,
38    /// Whether the function propagates nulls. Applies to binary functions.
39    propagates_nulls: Option<Expr>,
40    /// Whether the function introduces nulls. Applies to all functions.
41    introduces_nulls: Option<Expr>,
42}
43
44/// Implementation for the `#[sqlfunc]` macro. The first parameter is the attribute
45/// arguments, the second is the function body. The third parameter indicates
46/// whether to include the test function in the output.
47///
48/// The feature `test` must be enabled to include the test function.
49pub fn sqlfunc(
50    attr: TokenStream,
51    item: TokenStream,
52    include_test: bool,
53) -> darling::Result<TokenStream> {
54    let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
55    let modifiers = Modifiers::from_list(&attr_args)?;
56    let func = syn::parse2::<syn::ItemFn>(item.clone())?;
57
58    let tokens = match determine_parameters_arena(&func) {
59        (1, false) => unary_func(&func, modifiers),
60        (1, true) => Err(darling::Error::custom(
61            "Unary functions do not yet support RowArena.",
62        )),
63        (2, arena) => binary_func(&func, modifiers, arena),
64        (other, _) => Err(darling::Error::custom(format!(
65            "Unsupported function: {} parameters",
66            other
67        ))),
68    }?;
69
70    let test = include_test.then(|| generate_test(attr, item, &func.sig.ident));
71
72    Ok(quote! {
73        #tokens
74        #test
75    })
76}
77
/// Builds a snapshot test named `test_<name>` for the annotated function.
///
/// The attribute arguments and the function item are embedded as string literals.
/// The emitted test passes them to `mz_expr_derive_impl::test_sqlfunc_str` and
/// snapshots the result with `insta`, using the function name as snapshot name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let snapshot_name = name.to_string();
    let test_ident = Ident::new(&format!("test_{}", snapshot_name), name.span());
    let attr_source = attr.to_string();
    let item_source = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
        #[mz_ore::test]
        fn #test_ident() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_source, #item_source);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
95
96#[cfg(not(any(feature = "test", test)))]
97fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
98    quote! {}
99}
100
101/// Determines the number of parameters to the function. Returns the number of parameters and
102/// whether the last parameter is a `RowArena`.
103fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
104    let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
105        if let syn::FnArg::Typed(pat) = last {
106            if let syn::Type::Reference(reference) = &*pat.ty {
107                if let syn::Type::Path(path) = &*reference.elem {
108                    return path.path.is_ident("RowArena");
109                }
110            }
111        }
112        false
113    });
114    let parameters = func.sig.inputs.len();
115    if last_is_arena {
116        (parameters - 1, true)
117    } else {
118        (parameters, false)
119    }
120}
121
122/// Convert an identifier to a camel-cased identifier.
123fn camel_case(ident: &Ident) -> Ident {
124    let mut result = String::new();
125    let mut capitalize_next = true;
126    for c in ident.to_string().chars() {
127        if c == '_' {
128            capitalize_next = true;
129        } else if capitalize_next {
130            result.push(c.to_ascii_uppercase());
131            capitalize_next = false;
132        } else {
133            result.push(c);
134        }
135    }
136    Ident::new(&result, ident.span())
137}
138
139/// Determines the argument type of the nth argument of the function.
140///
141/// Adds a lifetime `'a` to the argument type if it is a reference type.
142///
143/// Panics if the function has fewer than `nth` arguments. Returns an error if
144/// the parameter is a `self` receiver.
145fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
146    match &arg.sig.inputs[nth] {
147        syn::FnArg::Typed(pat) => {
148            // Patch lifetimes to be 'a if reference
149            if let syn::Type::Reference(r) = &*pat.ty {
150                if r.lifetime.is_none() {
151                    let ty = syn::Type::Reference(syn::TypeReference {
152                        lifetime: Some(Lifetime::new("'a", r.span())),
153                        ..r.clone()
154                    });
155                    return Ok(ty);
156                }
157            }
158            Ok((*pat.ty).clone())
159        }
160        _ => Err(syn::Error::new(
161            arg.sig.inputs[nth].span(),
162            "Unsupported argument type",
163        )),
164    }
165}
166
167/// Determine the output type for a function. Returns an error if the function
168/// does not return a value.
169fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
170    match &arg.sig.output {
171        syn::ReturnType::Type(_, ty) => Ok(&*ty),
172        syn::ReturnType::Default => Err(syn::Error::new(
173            arg.sig.output.span(),
174            "Function needs to return a value",
175        )),
176    }
177}
178
179/// Produce a `EagerUnaryFunc` implementation.
180fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
181    let fn_name = &func.sig.ident;
182    let struct_name = camel_case(&func.sig.ident);
183    let input_ty = arg_type(func, 0)?;
184    let output_ty = output_type(func)?;
185    let Modifiers {
186        is_monotone,
187        sqlname,
188        preserves_uniqueness,
189        inverse,
190        is_infix_op,
191        output_type,
192        negate,
193        could_error,
194        propagates_nulls,
195        introduces_nulls,
196    } = modifiers;
197
198    if is_infix_op.is_some() {
199        return Err(darling::Error::unknown_field(
200            "is_infix_op not supported for unary functions",
201        ));
202    }
203    if negate.is_some() {
204        return Err(darling::Error::unknown_field(
205            "negate not supported for unary functions",
206        ));
207    }
208    if propagates_nulls.is_some() {
209        return Err(darling::Error::unknown_field(
210            "propagates_nulls not supported for unary functions",
211        ));
212    }
213
214    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
215        quote! {
216            fn preserves_uniqueness(&self) -> bool {
217                #preserves_uniqueness
218            }
219        }
220    });
221
222    let inverse_fn = inverse.as_ref().map(|inverse| {
223        quote! {
224            fn inverse(&self) -> Option<crate::UnaryFunc> {
225                #inverse
226            }
227        }
228    });
229
230    let is_monotone_fn = is_monotone.map(|is_monotone| {
231        quote! {
232            fn is_monotone(&self) -> bool {
233                #is_monotone
234            }
235        }
236    });
237
238    let name = if let Some(sqlname) = sqlname {
239        quote! {
240            #sqlname
241        }
242    } else {
243        quote! {
244            stringify!(#fn_name)
245        }
246    };
247
248    let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
249        let introduces_nulls_fn = quote! {
250            fn introduces_nulls(&self) -> bool {
251                <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
252            }
253        };
254        let output_type = quote! { <#output_type> };
255        (output_type, Some(introduces_nulls_fn))
256    } else {
257        (quote! { Self::Output }, None)
258    };
259
260    if let Some(introduces_nulls) = introduces_nulls {
261        introduces_nulls_fn = Some(quote! {
262            fn introduces_nulls(&self) -> bool {
263                #introduces_nulls
264            }
265        });
266    }
267
268    let could_error_fn = could_error.map(|could_error| {
269        quote! {
270            fn could_error(&self) -> bool {
271                #could_error
272            }
273        }
274    });
275
276    let result = quote! {
277        #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
278        pub struct #struct_name;
279
280        impl<'a> crate::func::EagerUnaryFunc<'a> for #struct_name {
281            type Input = #input_ty;
282            type Output = #output_ty;
283
284            fn call(&self, a: Self::Input) -> Self::Output {
285                #fn_name(a)
286            }
287
288            fn output_type(&self, input_type: mz_repr::ColumnType) -> mz_repr::ColumnType {
289                use mz_repr::AsColumnType;
290                let output = #output_type::as_column_type();
291                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
292                let nullable = output.nullable;
293                // The output is nullable if it is nullable by itself or the input is nullable
294                // and this function propagates nulls
295                output.nullable(nullable || (propagates_nulls && input_type.nullable))
296            }
297
298            #could_error_fn
299            #introduces_nulls_fn
300            #inverse_fn
301            #is_monotone_fn
302            #preserves_uniqueness_fn
303        }
304
305        impl std::fmt::Display for #struct_name {
306            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
307                f.write_str(#name)
308            }
309        }
310
311        #func
312    };
313    Ok(result)
314}
315
316/// Produce a `EagerBinaryFunc` implementation.
317fn binary_func(
318    func: &syn::ItemFn,
319    modifiers: Modifiers,
320    arena: bool,
321) -> darling::Result<TokenStream> {
322    let fn_name = &func.sig.ident;
323    let struct_name = camel_case(&func.sig.ident);
324    let input1_ty = arg_type(func, 0)?;
325    let input2_ty = arg_type(func, 1)?;
326    let output_ty = output_type(func)?;
327
328    let Modifiers {
329        is_monotone,
330        sqlname,
331        preserves_uniqueness,
332        inverse,
333        is_infix_op,
334        output_type,
335        negate,
336        could_error,
337        propagates_nulls,
338        introduces_nulls,
339    } = modifiers;
340
341    if preserves_uniqueness.is_some() {
342        return Err(darling::Error::unknown_field(
343            "preserves_uniqueness not supported for binary functions",
344        ));
345    }
346    if inverse.is_some() {
347        return Err(darling::Error::unknown_field(
348            "inverse not supported for binary functions",
349        ));
350    }
351
352    let negate_fn = negate.map(|negate| {
353        quote! {
354            fn negate(&self) -> Option<crate::BinaryFunc> {
355                #negate
356            }
357        }
358    });
359
360    let is_monotone_fn = is_monotone.map(|is_monotone| {
361        quote! {
362            fn is_monotone(&self) -> (bool, bool) {
363                #is_monotone
364            }
365        }
366    });
367
368    let name = if let Some(sqlname) = sqlname {
369        quote! {
370            #sqlname
371        }
372    } else {
373        quote! {
374            stringify!(#fn_name)
375        }
376    };
377
378    let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
379        let introduces_nulls_fn = quote! {
380            fn introduces_nulls(&self) -> bool {
381                <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
382            }
383        };
384        let output_type = quote! { <#output_type> };
385        (output_type, Some(introduces_nulls_fn))
386    } else {
387        (quote! { Self::Output }, None)
388    };
389
390    if let Some(introduces_nulls) = introduces_nulls {
391        introduces_nulls_fn = Some(quote! {
392            fn introduces_nulls(&self) -> bool {
393                #introduces_nulls
394            }
395        });
396    }
397
398    let arena = if arena {
399        quote! { , temp_storage }
400    } else {
401        quote! {}
402    };
403
404    let could_error_fn = could_error.map(|could_error| {
405        quote! {
406            fn could_error(&self) -> bool {
407                #could_error
408            }
409        }
410    });
411
412    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
413        quote! {
414            fn is_infix_op(&self) -> bool {
415                #is_infix_op
416            }
417        }
418    });
419
420    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
421        quote! {
422            fn propagates_nulls(&self) -> bool {
423                #propagates_nulls
424            }
425        }
426    });
427
428    let result = quote! {
429        #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
430        pub struct #struct_name;
431
432        impl<'a> crate::func::binary::EagerBinaryFunc<'a> for #struct_name {
433            type Input1 = #input1_ty;
434            type Input2 = #input2_ty;
435            type Output = #output_ty;
436
437            fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a mz_repr::RowArena) -> Self::Output {
438                #fn_name(a, b #arena)
439            }
440
441            fn output_type(&self, input_type_a: mz_repr::ColumnType, input_type_b: mz_repr::ColumnType) -> mz_repr::ColumnType {
442                use mz_repr::AsColumnType;
443                let output = #output_type::as_column_type();
444                let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
445                let nullable = output.nullable;
446                // The output is nullable if it is nullable by itself or the input is nullable
447                // and this function propagates nulls
448                output.nullable(nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)))
449            }
450
451            #could_error_fn
452            #introduces_nulls_fn
453            #is_infix_op_fn
454            #is_monotone_fn
455            #negate_fn
456            #propagates_nulls_fn
457        }
458
459        impl std::fmt::Display for #struct_name {
460            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
461                f.write_str(#name)
462            }
463        }
464
465        #func
466
467    };
468    Ok(result)
469}