Skip to main content

mz_expr_derive_impl/
sqlfunc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Modifiers passed as key-value pairs to the `#[sqlfunc]` macro.
///
/// Every modifier is optional at parse time. Which ones are accepted depends on
/// whether the annotated function is unary or binary; the code generators reject
/// modifiers that do not apply to the function's arity.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// An optional expression used as the body of the generated `is_monotone` method.
    /// For unary functions it must evaluate to a `bool`; for binary functions it must
    /// evaluate to a `(bool, bool)`.
    is_monotone: Option<Expr>,
    /// The SQL name for the function. Applies to all functions. When absent, the
    /// stringified Rust function name is used.
    sqlname: Option<SqlName>,
    /// Whether the function preserves uniqueness. Applies to unary functions; rejected
    /// for binary functions.
    preserves_uniqueness: Option<Expr>,
    /// The inverse of the function, if it exists, as an expression evaluating to
    /// `Option<crate::UnaryFunc>`. Applies to unary functions; rejected for binary functions.
    inverse: Option<Expr>,
    /// The negated function, if it exists, as an expression evaluating to
    /// `Option<crate::BinaryFunc>`. Applies to binary functions; rejected for unary functions.
    negate: Option<Expr>,
    /// Whether the function is an infix operator. Applies to binary functions, and needs to
    /// be specified. Rejected for unary functions.
    // NOTE(review): the macro itself does not check that this is present for binary
    // functions; presumably the trait has no default for it — confirm.
    is_infix_op: Option<Expr>,
    /// The output type of the function, if it cannot be inferred. Applies to all functions.
    /// Cannot be combined with `output_type_expr`.
    output_type: Option<syn::Path>,
    /// The output type of the function as an expression. Applies to unary and binary
    /// functions. Requires `introduces_nulls` and cannot be combined with `output_type`.
    output_type_expr: Option<Expr>,
    /// Optional expression evaluating to a boolean indicating whether the function could error.
    /// Applies to all functions.
    could_error: Option<Expr>,
    /// Whether the function propagates nulls. Applies to binary functions; rejected for
    /// unary functions.
    propagates_nulls: Option<Expr>,
    /// Whether the function introduces nulls. Applies to all functions. Overrides the
    /// implementation derived from `output_type`, if both are given.
    introduces_nulls: Option<Expr>,
    /// Whether to generate a snapshot test for the function. Defaults to false. The test
    /// is only emitted when the caller of `sqlfunc` also asks for tests to be included.
    test: Option<bool>,
}
47
/// A name for the SQL function. It can be either a literal or a macro, thus we
/// can't use `String` or `syn::Expr` directly.
///
/// Either form is spliced verbatim into the generated `Display` impl as the
/// argument of `f.write_str(..)`.
#[derive(Debug)]
enum SqlName {
    /// A literal, stored exactly as written in the attribute.
    Literal(syn::Lit),
    /// A macro invocation, stored unexpanded.
    Macro(syn::ExprMacro),
}
57
58impl quote::ToTokens for SqlName {
59    fn to_tokens(&self, tokens: &mut TokenStream) {
60        let name = match self {
61            SqlName::Literal(lit) => quote! { #lit },
62            SqlName::Macro(mac) => quote! { #mac },
63        };
64        tokens.extend(name);
65    }
66}
67
68impl darling::FromMeta for SqlName {
69    fn from_value(value: &Lit) -> darling::Result<Self> {
70        Ok(Self::Literal(value.clone()))
71    }
72    fn from_expr(expr: &Expr) -> darling::Result<Self> {
73        match expr {
74            Expr::Lit(lit) => Self::from_value(&lit.lit),
75            Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
76            // Syn sometimes inserts groups, see `FromMeta::from_expr` for
77            // details.
78            Expr::Group(mac) => Self::from_expr(&mac.expr),
79            _ => Err(darling::Error::unexpected_expr_type(expr)),
80        }
81    }
82}
83
84/// Implementation for the `#[sqlfunc]` macro. The first parameter is the attribute
85/// arguments, the second is the function body. The third parameter indicates
86/// whether to include the test function in the output.
87///
88/// The feature `test` must be enabled to include the test function.
89pub fn sqlfunc(
90    attr: TokenStream,
91    item: TokenStream,
92    include_test: bool,
93) -> darling::Result<TokenStream> {
94    let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
95    let modifiers = Modifiers::from_list(&attr_args).unwrap();
96    let generate_tests = modifiers.test.unwrap_or(false);
97    let func = syn::parse2::<syn::ItemFn>(item.clone())?;
98
99    let tokens = match determine_parameters_arena(&func) {
100        (1, false) => unary_func(&func, modifiers),
101        (1, true) => Err(darling::Error::custom(
102            "Unary functions do not yet support RowArena.",
103        )),
104        (2, arena) => binary_func(&func, modifiers, arena),
105        (other, _) => Err(darling::Error::custom(format!(
106            "Unsupported function: {} parameters",
107            other
108        ))),
109    }?;
110
111    let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
112
113    Ok(quote! {
114        #tokens
115        #test
116    })
117}
118
/// Builds a snapshot test named `test_<name>` for the annotated function.
///
/// The attribute arguments and the function item are embedded in the test as
/// their string renderings; the function name doubles as the snapshot name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let snapshot_name = name.to_string();
    let test_ident = Ident::new(&format!("test_{}", snapshot_name), name.span());
    let attr_source = attr.to_string();
    let item_source = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
        #[mz_ore::test]
        fn #test_ident() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_source, #item_source);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
136
137#[cfg(not(any(feature = "test", test)))]
138fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
139    quote! {}
140}
141
142/// Determines the number of parameters to the function. Returns the number of parameters and
143/// whether the last parameter is a `RowArena`.
144fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
145    let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
146        if let syn::FnArg::Typed(pat) = last {
147            if let syn::Type::Reference(reference) = &*pat.ty {
148                if let syn::Type::Path(path) = &*reference.elem {
149                    return path.path.is_ident("RowArena");
150                }
151            }
152        }
153        false
154    });
155    let parameters = func.sig.inputs.len();
156    if last_is_arena {
157        (parameters - 1, true)
158    } else {
159        (parameters, false)
160    }
161}
162
163/// Convert an identifier to a camel-cased identifier.
164fn camel_case(ident: &Ident) -> Ident {
165    let mut result = String::new();
166    let mut capitalize_next = true;
167    for c in ident.to_string().chars() {
168        if c == '_' {
169            capitalize_next = true;
170        } else if capitalize_next {
171            result.push(c.to_ascii_uppercase());
172            capitalize_next = false;
173        } else {
174            result.push(c);
175        }
176    }
177    Ident::new(&result, ident.span())
178}
179
180/// Determines the argument type of the nth argument of the function.
181///
182/// Adds a lifetime `'a` to the argument type if it is a reference type.
183///
184/// Panics if the function has fewer than `nth` arguments. Returns an error if
185/// the parameter is a `self` receiver.
186fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
187    match &arg.sig.inputs[nth] {
188        syn::FnArg::Typed(pat) => {
189            // Patch lifetimes to be 'a if reference
190            if let syn::Type::Reference(r) = &*pat.ty {
191                if r.lifetime.is_none() {
192                    let ty = syn::Type::Reference(syn::TypeReference {
193                        lifetime: Some(Lifetime::new("'a", r.span())),
194                        ..r.clone()
195                    });
196                    return Ok(ty);
197                }
198            }
199            Ok((*pat.ty).clone())
200        }
201        _ => Err(syn::Error::new(
202            arg.sig.inputs[nth].span(),
203            "Unsupported argument type",
204        )),
205    }
206}
207
208/// Determine the output type for a function. Returns an error if the function
209/// does not return a value.
210fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
211    match &arg.sig.output {
212        syn::ReturnType::Type(_, ty) => Ok(&*ty),
213        syn::ReturnType::Default => Err(syn::Error::new(
214            arg.sig.output.span(),
215            "Function needs to return a value",
216        )),
217    }
218}
219
220/// Produce a `EagerUnaryFunc` implementation.
221fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
222    let fn_name = &func.sig.ident;
223    let struct_name = camel_case(&func.sig.ident);
224    let input_ty = arg_type(func, 0)?;
225    let output_ty = output_type(func)?;
226    let Modifiers {
227        is_monotone,
228        sqlname,
229        preserves_uniqueness,
230        inverse,
231        is_infix_op,
232        output_type,
233        output_type_expr,
234        negate,
235        could_error,
236        propagates_nulls,
237        introduces_nulls,
238        test: _,
239    } = modifiers;
240
241    if is_infix_op.is_some() {
242        return Err(darling::Error::unknown_field(
243            "is_infix_op not supported for unary functions",
244        ));
245    }
246    if output_type.is_some() && output_type_expr.is_some() {
247        return Err(darling::Error::unknown_field(
248            "output_type and output_type_expr cannot be used together",
249        ));
250    }
251    if output_type_expr.is_some() && introduces_nulls.is_none() {
252        return Err(darling::Error::unknown_field(
253            "output_type_expr requires introduces_nulls",
254        ));
255    }
256    if negate.is_some() {
257        return Err(darling::Error::unknown_field(
258            "negate not supported for unary functions",
259        ));
260    }
261    if propagates_nulls.is_some() {
262        return Err(darling::Error::unknown_field(
263            "propagates_nulls not supported for unary functions",
264        ));
265    }
266
267    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
268        quote! {
269            fn preserves_uniqueness(&self) -> bool {
270                #preserves_uniqueness
271            }
272        }
273    });
274
275    let inverse_fn = inverse.as_ref().map(|inverse| {
276        quote! {
277            fn inverse(&self) -> Option<crate::UnaryFunc> {
278                #inverse
279            }
280        }
281    });
282
283    let is_monotone_fn = is_monotone.map(|is_monotone| {
284        quote! {
285            fn is_monotone(&self) -> bool {
286                #is_monotone
287            }
288        }
289    });
290
291    let name = sqlname
292        .as_ref()
293        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
294
295    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
296        let introduces_nulls_fn = quote! {
297            fn introduces_nulls(&self) -> bool {
298                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
299            }
300        };
301        let output_type = quote! { <#output_type>::as_column_type() };
302        (output_type, Some(introduces_nulls_fn))
303    } else {
304        (quote! { Self::Output::as_column_type() }, None)
305    };
306
307    if let Some(output_type_expr) = output_type_expr {
308        output_type = quote! { #output_type_expr };
309    }
310
311    if let Some(introduces_nulls) = introduces_nulls {
312        introduces_nulls_fn = Some(quote! {
313            fn introduces_nulls(&self) -> bool {
314                #introduces_nulls
315            }
316        });
317    }
318
319    let could_error_fn = could_error.map(|could_error| {
320        quote! {
321            fn could_error(&self) -> bool {
322                #could_error
323            }
324        }
325    });
326
327    let result = quote! {
328        #[derive(
329            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
330            Debug, Eq, PartialEq, serde::Serialize,
331            serde::Deserialize, Hash, mz_lowertest::MzReflect,
332        )]
333        pub struct #struct_name;
334
335        impl crate::func::EagerUnaryFunc for #struct_name {
336            type Input<'a> = #input_ty;
337            type Output<'a> = #output_ty;
338
339            fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
340                #fn_name(a)
341            }
342
343            fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType {
344                use mz_repr::AsColumnType;
345                let output = #output_type;
346                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
347                let nullable = output.nullable;
348                // The output is nullable if it is nullable by itself or the input is nullable
349                // and this function propagates nulls
350                output.nullable(nullable || (propagates_nulls && input_type.nullable))
351            }
352
353            #could_error_fn
354            #introduces_nulls_fn
355            #inverse_fn
356            #is_monotone_fn
357            #preserves_uniqueness_fn
358        }
359
360        impl std::fmt::Display for #struct_name {
361            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
362                f.write_str(#name)
363            }
364        }
365
366        #func
367    };
368    Ok(result)
369}
370
371/// Produce a `EagerBinaryFunc` implementation.
372fn binary_func(
373    func: &syn::ItemFn,
374    modifiers: Modifiers,
375    arena: bool,
376) -> darling::Result<TokenStream> {
377    let fn_name = &func.sig.ident;
378    let struct_name = camel_case(&func.sig.ident);
379    let input1_ty = arg_type(func, 0)?;
380    let input2_ty = arg_type(func, 1)?;
381    let output_ty = output_type(func)?;
382
383    let Modifiers {
384        is_monotone,
385        sqlname,
386        preserves_uniqueness,
387        inverse,
388        is_infix_op,
389        output_type,
390        output_type_expr,
391        negate,
392        could_error,
393        propagates_nulls,
394        introduces_nulls,
395        test: _,
396    } = modifiers;
397
398    if preserves_uniqueness.is_some() {
399        return Err(darling::Error::unknown_field(
400            "preserves_uniqueness not supported for binary functions",
401        ));
402    }
403    if inverse.is_some() {
404        return Err(darling::Error::unknown_field(
405            "inverse not supported for binary functions",
406        ));
407    }
408    if output_type.is_some() && output_type_expr.is_some() {
409        return Err(darling::Error::unknown_field(
410            "output_type and output_type_expr cannot be used together",
411        ));
412    }
413    if output_type_expr.is_some() && introduces_nulls.is_none() {
414        return Err(darling::Error::unknown_field(
415            "output_type_expr requires introduces_nulls",
416        ));
417    }
418
419    let negate_fn = negate.map(|negate| {
420        quote! {
421            fn negate(&self) -> Option<crate::BinaryFunc> {
422                #negate
423            }
424        }
425    });
426
427    let is_monotone_fn = is_monotone.map(|is_monotone| {
428        quote! {
429            fn is_monotone(&self) -> (bool, bool) {
430                #is_monotone
431            }
432        }
433    });
434
435    let name = sqlname
436        .as_ref()
437        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
438
439    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
440        let introduces_nulls_fn = quote! {
441            fn introduces_nulls(&self) -> bool {
442                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
443            }
444        };
445        let output_type = quote! { <#output_type>::as_column_type() };
446        (output_type, Some(introduces_nulls_fn))
447    } else {
448        (quote! { Self::Output::as_column_type() }, None)
449    };
450
451    if let Some(output_type_expr) = output_type_expr {
452        output_type = quote! { #output_type_expr };
453    }
454
455    if let Some(introduces_nulls) = introduces_nulls {
456        introduces_nulls_fn = Some(quote! {
457            fn introduces_nulls(&self) -> bool {
458                #introduces_nulls
459            }
460        });
461    }
462
463    let arena = if arena {
464        quote! { , temp_storage }
465    } else {
466        quote! {}
467    };
468
469    let could_error_fn = could_error.map(|could_error| {
470        quote! {
471            fn could_error(&self) -> bool {
472                #could_error
473            }
474        }
475    });
476
477    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
478        quote! {
479            fn is_infix_op(&self) -> bool {
480                #is_infix_op
481            }
482        }
483    });
484
485    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
486        quote! {
487            fn propagates_nulls(&self) -> bool {
488                #propagates_nulls
489            }
490        }
491    });
492
493    let result = quote! {
494        #[derive(
495            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
496            Debug, Eq, PartialEq, serde::Serialize,
497            serde::Deserialize, Hash, mz_lowertest::MzReflect,
498        )]
499        pub struct #struct_name;
500
501        impl crate::func::binary::EagerBinaryFunc for #struct_name {
502            type Input<'a> = (#input1_ty, #input2_ty);
503            type Output<'a> = #output_ty;
504
505            fn call<'a>(
506                &self,
507                (a, b): Self::Input<'a>,
508                temp_storage: &'a mz_repr::RowArena
509            ) -> Self::Output<'a> {
510                #fn_name(a, b #arena)
511            }
512
513            fn output_type(
514                &self,
515                input_types: &[mz_repr::SqlColumnType],
516            ) -> mz_repr::SqlColumnType {
517                use mz_repr::AsColumnType;
518                let output = #output_type;
519                let propagates_nulls =
520                    crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
521                let nullable = output.nullable;
522                // The output is nullable if it is nullable by itself
523                // or the input is nullable and this function
524                // propagates nulls
525                let inputs_nullable = input_types.iter().any(|it| it.nullable);
526                let is_null = nullable || (propagates_nulls && inputs_nullable);
527                output.nullable(is_null)
528            }
529
530            #could_error_fn
531            #introduces_nulls_fn
532            #is_infix_op_fn
533            #is_monotone_fn
534            #negate_fn
535            #propagates_nulls_fn
536        }
537
538        impl std::fmt::Display for #struct_name {
539            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
540                f.write_str(#name)
541            }
542        }
543
544        #func
545
546    };
547    Ok(result)
548}