mz_expr_derive_impl/
sqlfunc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
16/// Modifiers passed as key-value pairs to the `#[sqlfunc]` macro.
17#[derive(Debug, Default, darling::FromMeta)]
18pub(crate) struct Modifiers {
19    /// An optional expression that evaluates to a boolean indicating whether the function is
20    /// monotone with respect to its arguments. Defined for unary and binary functions.
21    is_monotone: Option<Expr>,
22    /// The SQL name for the function. Applies to all functions.
23    sqlname: Option<SqlName>,
24    /// Whether the function preserves uniqueness. Applies to unary functions.
25    preserves_uniqueness: Option<Expr>,
26    /// The inverse of the function, if it exists. Applies to unary functions.
27    inverse: Option<Expr>,
28    /// The negated function, if it exists. Applies to binary functions.
29    negate: Option<Expr>,
30    /// Whether the function is an infix operator. Applies to binary functions, and needs to
31    /// be specified.
32    is_infix_op: Option<Expr>,
33    /// The output type of the function, if it cannot be inferred. Applies to all functions.
34    output_type: Option<syn::Path>,
35    /// The output type of the function as an expression. Applies to binary functions.
36    output_type_expr: Option<Expr>,
37    /// Optional expression evaluating to a boolean indicating whether the function could error.
38    /// Applies to all functions.
39    could_error: Option<Expr>,
40    /// Whether the function propagates nulls. Applies to binary functions.
41    propagates_nulls: Option<Expr>,
42    /// Whether the function introduces nulls. Applies to all functions.
43    introduces_nulls: Option<Expr>,
44}
45
46/// A name for the SQL function. It can be either a literal or a macro, thus we
47/// can't use `String` or `syn::Expr` directly.
48#[derive(Debug)]
49enum SqlName {
50    /// A literal string.
51    Literal(syn::Lit),
52    /// A macro expression.
53    Macro(syn::ExprMacro),
54}
55
56impl quote::ToTokens for SqlName {
57    fn to_tokens(&self, tokens: &mut TokenStream) {
58        let name = match self {
59            SqlName::Literal(lit) => quote! { #lit },
60            SqlName::Macro(mac) => quote! { #mac },
61        };
62        tokens.extend(name);
63    }
64}
65
66impl darling::FromMeta for SqlName {
67    fn from_value(value: &Lit) -> darling::Result<Self> {
68        Ok(Self::Literal(value.clone()))
69    }
70    fn from_expr(expr: &Expr) -> darling::Result<Self> {
71        match expr {
72            Expr::Lit(lit) => Self::from_value(&lit.lit),
73            Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
74            // Syn sometimes inserts groups, see `FromMeta::from_expr` for
75            // details.
76            Expr::Group(mac) => Self::from_expr(&mac.expr),
77            _ => Err(darling::Error::unexpected_expr_type(expr)),
78        }
79    }
80}
81
82/// Implementation for the `#[sqlfunc]` macro. The first parameter is the attribute
83/// arguments, the second is the function body. The third parameter indicates
84/// whether to include the test function in the output.
85///
86/// The feature `test` must be enabled to include the test function.
87pub fn sqlfunc(
88    attr: TokenStream,
89    item: TokenStream,
90    include_test: bool,
91) -> darling::Result<TokenStream> {
92    let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
93    let modifiers = Modifiers::from_list(&attr_args).unwrap();
94    let func = syn::parse2::<syn::ItemFn>(item.clone())?;
95
96    let tokens = match determine_parameters_arena(&func) {
97        (1, false) => unary_func(&func, modifiers),
98        (1, true) => Err(darling::Error::custom(
99            "Unary functions do not yet support RowArena.",
100        )),
101        (2, arena) => binary_func(&func, modifiers, arena),
102        (other, _) => Err(darling::Error::custom(format!(
103            "Unsupported function: {} parameters",
104            other
105        ))),
106    }?;
107
108    let test = include_test.then(|| generate_test(attr, item, &func.sig.ident));
109
110    Ok(quote! {
111        #tokens
112        #test
113    })
114}
115
/// Emits a `#[cfg(test)]` function named `test_<name>`. The function re-runs the
/// expansion on the original attribute and item source and snapshots the result
/// with `insta`.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    // The macro input is embedded as string literals so the generated test can
    // reproduce the expansion at test time.
    let attr_source = attr.to_string();
    let item_source = item.to_string();
    let snapshot_name = name.to_string();
    let test_fn = Ident::new(&format!("test_{name}"), name.span());

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
        #[mz_ore::test]
        fn #test_fn() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_source, #item_source);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
133
134#[cfg(not(any(feature = "test", test)))]
135fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
136    quote! {}
137}
138
139/// Determines the number of parameters to the function. Returns the number of parameters and
140/// whether the last parameter is a `RowArena`.
141fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
142    let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
143        if let syn::FnArg::Typed(pat) = last {
144            if let syn::Type::Reference(reference) = &*pat.ty {
145                if let syn::Type::Path(path) = &*reference.elem {
146                    return path.path.is_ident("RowArena");
147                }
148            }
149        }
150        false
151    });
152    let parameters = func.sig.inputs.len();
153    if last_is_arena {
154        (parameters - 1, true)
155    } else {
156        (parameters, false)
157    }
158}
159
160/// Convert an identifier to a camel-cased identifier.
161fn camel_case(ident: &Ident) -> Ident {
162    let mut result = String::new();
163    let mut capitalize_next = true;
164    for c in ident.to_string().chars() {
165        if c == '_' {
166            capitalize_next = true;
167        } else if capitalize_next {
168            result.push(c.to_ascii_uppercase());
169            capitalize_next = false;
170        } else {
171            result.push(c);
172        }
173    }
174    Ident::new(&result, ident.span())
175}
176
177/// Determines the argument type of the nth argument of the function.
178///
179/// Adds a lifetime `'a` to the argument type if it is a reference type.
180///
181/// Panics if the function has fewer than `nth` arguments. Returns an error if
182/// the parameter is a `self` receiver.
183fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
184    match &arg.sig.inputs[nth] {
185        syn::FnArg::Typed(pat) => {
186            // Patch lifetimes to be 'a if reference
187            if let syn::Type::Reference(r) = &*pat.ty {
188                if r.lifetime.is_none() {
189                    let ty = syn::Type::Reference(syn::TypeReference {
190                        lifetime: Some(Lifetime::new("'a", r.span())),
191                        ..r.clone()
192                    });
193                    return Ok(ty);
194                }
195            }
196            Ok((*pat.ty).clone())
197        }
198        _ => Err(syn::Error::new(
199            arg.sig.inputs[nth].span(),
200            "Unsupported argument type",
201        )),
202    }
203}
204
205/// Determine the output type for a function. Returns an error if the function
206/// does not return a value.
207fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
208    match &arg.sig.output {
209        syn::ReturnType::Type(_, ty) => Ok(&*ty),
210        syn::ReturnType::Default => Err(syn::Error::new(
211            arg.sig.output.span(),
212            "Function needs to return a value",
213        )),
214    }
215}
216
217/// Produce a `EagerUnaryFunc` implementation.
218fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
219    let fn_name = &func.sig.ident;
220    let struct_name = camel_case(&func.sig.ident);
221    let input_ty = arg_type(func, 0)?;
222    let output_ty = output_type(func)?;
223    let Modifiers {
224        is_monotone,
225        sqlname,
226        preserves_uniqueness,
227        inverse,
228        is_infix_op,
229        output_type,
230        output_type_expr,
231        negate,
232        could_error,
233        propagates_nulls,
234        introduces_nulls,
235    } = modifiers;
236
237    if is_infix_op.is_some() {
238        return Err(darling::Error::unknown_field(
239            "is_infix_op not supported for unary functions",
240        ));
241    }
242    if output_type_expr.is_some() {
243        return Err(darling::Error::unknown_field(
244            "output_type_expr not supported for unary functions",
245        ));
246    }
247    if negate.is_some() {
248        return Err(darling::Error::unknown_field(
249            "negate not supported for unary functions",
250        ));
251    }
252    if propagates_nulls.is_some() {
253        return Err(darling::Error::unknown_field(
254            "propagates_nulls not supported for unary functions",
255        ));
256    }
257
258    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
259        quote! {
260            fn preserves_uniqueness(&self) -> bool {
261                #preserves_uniqueness
262            }
263        }
264    });
265
266    let inverse_fn = inverse.as_ref().map(|inverse| {
267        quote! {
268            fn inverse(&self) -> Option<crate::UnaryFunc> {
269                #inverse
270            }
271        }
272    });
273
274    let is_monotone_fn = is_monotone.map(|is_monotone| {
275        quote! {
276            fn is_monotone(&self) -> bool {
277                #is_monotone
278            }
279        }
280    });
281
282    let name = sqlname
283        .as_ref()
284        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
285
286    let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
287        let introduces_nulls_fn = quote! {
288            fn introduces_nulls(&self) -> bool {
289                <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
290            }
291        };
292        let output_type = quote! { <#output_type> };
293        (output_type, Some(introduces_nulls_fn))
294    } else {
295        (quote! { Self::Output }, None)
296    };
297
298    if let Some(introduces_nulls) = introduces_nulls {
299        introduces_nulls_fn = Some(quote! {
300            fn introduces_nulls(&self) -> bool {
301                #introduces_nulls
302            }
303        });
304    }
305
306    let could_error_fn = could_error.map(|could_error| {
307        quote! {
308            fn could_error(&self) -> bool {
309                #could_error
310            }
311        }
312    });
313
314    let result = quote! {
315        #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
316        pub struct #struct_name;
317
318        impl<'a> crate::func::EagerUnaryFunc<'a> for #struct_name {
319            type Input = #input_ty;
320            type Output = #output_ty;
321
322            fn call(&self, a: Self::Input) -> Self::Output {
323                #fn_name(a)
324            }
325
326            fn output_type(&self, input_type: mz_repr::ColumnType) -> mz_repr::ColumnType {
327                use mz_repr::AsColumnType;
328                let output = #output_type::as_column_type();
329                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
330                let nullable = output.nullable;
331                // The output is nullable if it is nullable by itself or the input is nullable
332                // and this function propagates nulls
333                output.nullable(nullable || (propagates_nulls && input_type.nullable))
334            }
335
336            #could_error_fn
337            #introduces_nulls_fn
338            #inverse_fn
339            #is_monotone_fn
340            #preserves_uniqueness_fn
341        }
342
343        impl std::fmt::Display for #struct_name {
344            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
345                f.write_str(#name)
346            }
347        }
348
349        #func
350    };
351    Ok(result)
352}
353
354/// Produce a `EagerBinaryFunc` implementation.
355fn binary_func(
356    func: &syn::ItemFn,
357    modifiers: Modifiers,
358    arena: bool,
359) -> darling::Result<TokenStream> {
360    let fn_name = &func.sig.ident;
361    let struct_name = camel_case(&func.sig.ident);
362    let input1_ty = arg_type(func, 0)?;
363    let input2_ty = arg_type(func, 1)?;
364    let output_ty = output_type(func)?;
365
366    let Modifiers {
367        is_monotone,
368        sqlname,
369        preserves_uniqueness,
370        inverse,
371        is_infix_op,
372        output_type,
373        output_type_expr,
374        negate,
375        could_error,
376        propagates_nulls,
377        introduces_nulls,
378    } = modifiers;
379
380    if preserves_uniqueness.is_some() {
381        return Err(darling::Error::unknown_field(
382            "preserves_uniqueness not supported for binary functions",
383        ));
384    }
385    if inverse.is_some() {
386        return Err(darling::Error::unknown_field(
387            "inverse not supported for binary functions",
388        ));
389    }
390    if output_type.is_some() && output_type_expr.is_some() {
391        return Err(darling::Error::unknown_field(
392            "output_type and output_type_expr cannot be used together",
393        ));
394    }
395    if output_type_expr.is_some() && introduces_nulls.is_none() {
396        return Err(darling::Error::unknown_field(
397            "output_type_expr requires introduces_nulls",
398        ));
399    }
400
401    let negate_fn = negate.map(|negate| {
402        quote! {
403            fn negate(&self) -> Option<crate::BinaryFunc> {
404                #negate
405            }
406        }
407    });
408
409    let is_monotone_fn = is_monotone.map(|is_monotone| {
410        quote! {
411            fn is_monotone(&self) -> (bool, bool) {
412                #is_monotone
413            }
414        }
415    });
416
417    let name = sqlname
418        .as_ref()
419        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
420
421    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
422        let introduces_nulls_fn = quote! {
423            fn introduces_nulls(&self) -> bool {
424                <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
425            }
426        };
427        let output_type = quote! { <#output_type>::as_column_type() };
428        (output_type, Some(introduces_nulls_fn))
429    } else {
430        (quote! { Self::Output::as_column_type() }, None)
431    };
432
433    if let Some(output_type_expr) = output_type_expr {
434        output_type = quote! { #output_type_expr };
435    }
436
437    if let Some(introduces_nulls) = introduces_nulls {
438        introduces_nulls_fn = Some(quote! {
439            fn introduces_nulls(&self) -> bool {
440                #introduces_nulls
441            }
442        });
443    }
444
445    let arena = if arena {
446        quote! { , temp_storage }
447    } else {
448        quote! {}
449    };
450
451    let could_error_fn = could_error.map(|could_error| {
452        quote! {
453            fn could_error(&self) -> bool {
454                #could_error
455            }
456        }
457    });
458
459    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
460        quote! {
461            fn is_infix_op(&self) -> bool {
462                #is_infix_op
463            }
464        }
465    });
466
467    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
468        quote! {
469            fn propagates_nulls(&self) -> bool {
470                #propagates_nulls
471            }
472        }
473    });
474
475    let result = quote! {
476        #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
477        pub struct #struct_name;
478
479        impl<'a> crate::func::binary::EagerBinaryFunc<'a> for #struct_name {
480            type Input1 = #input1_ty;
481            type Input2 = #input2_ty;
482            type Output = #output_ty;
483
484            fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a mz_repr::RowArena) -> Self::Output {
485                #fn_name(a, b #arena)
486            }
487
488            fn output_type(&self, input_type_a: mz_repr::ColumnType, input_type_b: mz_repr::ColumnType) -> mz_repr::ColumnType {
489                use mz_repr::AsColumnType;
490                let output = #output_type;
491                let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
492                let nullable = output.nullable;
493                // The output is nullable if it is nullable by itself or the input is nullable
494                // and this function propagates nulls
495                output.nullable(nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)))
496            }
497
498            #could_error_fn
499            #introduces_nulls_fn
500            #is_infix_op_fn
501            #is_monotone_fn
502            #negate_fn
503            #propagates_nulls_fn
504        }
505
506        impl std::fmt::Display for #struct_name {
507            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
508                f.write_str(#name)
509            }
510        }
511
512        #func
513
514    };
515    Ok(result)
516}