Skip to main content

mz_expr_derive_impl/
sqlfunc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
16/// Modifiers passed as key-value pairs to the `#[sqlfunc]` macro.
17#[derive(Debug, Default, darling::FromMeta)]
18pub(crate) struct Modifiers {
19    /// An optional expression that evaluates to a boolean indicating whether the function is
20    /// monotone with respect to its arguments. Defined for unary and binary functions.
21    is_monotone: Option<Expr>,
22    /// The SQL name for the function. Applies to all functions.
23    sqlname: Option<SqlName>,
24    /// Whether the function preserves uniqueness. Applies to unary functions.
25    preserves_uniqueness: Option<Expr>,
26    /// The inverse of the function, if it exists. Applies to unary functions.
27    inverse: Option<Expr>,
28    /// The negated function, if it exists. Applies to binary functions.
29    negate: Option<Expr>,
30    /// Whether the function is an infix operator. Applies to binary functions, and needs to
31    /// be specified.
32    is_infix_op: Option<Expr>,
33    /// The output type of the function, if it cannot be inferred. Applies to all functions.
34    output_type: Option<syn::Path>,
35    /// The output type of the function as an expression. Applies to binary and variadic functions.
36    output_type_expr: Option<Expr>,
37    /// Optional expression evaluating to a boolean indicating whether the function could error.
38    /// Applies to all functions.
39    could_error: Option<Expr>,
40    /// Whether the function propagates nulls. Applies to binary and variadic functions.
41    propagates_nulls: Option<Expr>,
42    /// Whether the function introduces nulls. Applies to all functions.
43    introduces_nulls: Option<Expr>,
44    /// Whether the function is associative. Applies to variadic functions.
45    is_associative: Option<Expr>,
46    /// Whether to generate a snapshot test for the function. Defaults to false.
47    test: Option<bool>,
48}
49
50/// A name for the SQL function. It can be either a literal or a macro, thus we
51/// can't use `String` or `syn::Expr` directly.
52#[derive(Debug)]
53enum SqlName {
54    /// A literal string.
55    Literal(syn::Lit),
56    /// A macro expression.
57    Macro(syn::ExprMacro),
58}
59
60impl quote::ToTokens for SqlName {
61    fn to_tokens(&self, tokens: &mut TokenStream) {
62        let name = match self {
63            SqlName::Literal(lit) => quote! { #lit },
64            SqlName::Macro(mac) => quote! { #mac },
65        };
66        tokens.extend(name);
67    }
68}
69
70impl darling::FromMeta for SqlName {
71    fn from_value(value: &Lit) -> darling::Result<Self> {
72        Ok(Self::Literal(value.clone()))
73    }
74    fn from_expr(expr: &Expr) -> darling::Result<Self> {
75        match expr {
76            Expr::Lit(lit) => Self::from_value(&lit.lit),
77            Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
78            // Syn sometimes inserts groups, see `FromMeta::from_expr` for
79            // details.
80            Expr::Group(mac) => Self::from_expr(&mac.expr),
81            _ => Err(darling::Error::unexpected_expr_type(expr)),
82        }
83    }
84}
85
86/// Implementation for the `#[sqlfunc]` macro. The first parameter is the attribute
87/// arguments, the second is the function body. The third parameter indicates
88/// whether to include the test function in the output.
89///
90/// The feature `test` must be enabled to include the test function.
91pub fn sqlfunc(
92    attr: TokenStream,
93    item: TokenStream,
94    include_test: bool,
95) -> darling::Result<TokenStream> {
96    let mut attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
97
98    // Check if the first attribute arg is a bare Path (struct name for variadic).
99    let struct_ty = match attr_args.first() {
100        Some(darling::ast::NestedMeta::Meta(syn::Meta::Path(_))) => {
101            let darling::ast::NestedMeta::Meta(syn::Meta::Path(path)) = attr_args.remove(0) else {
102                unreachable!()
103            };
104            Some(path)
105        }
106        _ => None,
107    };
108
109    let modifiers = Modifiers::from_list(&attr_args).unwrap();
110    let generate_tests = modifiers.test.unwrap_or(false);
111    let func = syn::parse2::<syn::ItemFn>(item.clone())?;
112
113    let tokens = match determine_arity(&func) {
114        Arity::Nullary => Err(darling::Error::custom("Nullary functions not supported")),
115        Arity::Unary { arena: false } => unary_func(&func, modifiers),
116        Arity::Unary { arena: true } => Err(darling::Error::custom(
117            "Unary functions do not yet support RowArena.",
118        )),
119        Arity::Binary { arena } => binary_func(&func, modifiers, arena),
120        Arity::Variadic { arena, has_self } => {
121            variadic_func(&func, modifiers, struct_ty, arena, has_self)
122        }
123    }?;
124
125    let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
126
127    Ok(quote! {
128        #tokens
129        #test
130    })
131}
132
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    // The snapshot is keyed by the function name; the generated test function is
    // `test_<name>`, spanned at the original identifier.
    let snapshot_name = name.to_string();
    let test_fn = Ident::new(&format!("test_{}", snapshot_name), name.span());
    // The macro's input is embedded as string literals and handed to
    // `test_sqlfunc_str` when the test runs.
    let attr_src = attr.to_string();
    let item_src = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
        #[mz_ore::test]
        fn #test_fn() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
150
151#[cfg(not(any(feature = "test", test)))]
152fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
153    quote! {}
154}
155
156/// Checks if the last parameter of the function is a `&RowArena`.
157fn last_is_arena(func: &syn::ItemFn) -> bool {
158    func.sig.inputs.last().map_or(false, |last| {
159        if let syn::FnArg::Typed(pat) = last {
160            if let syn::Type::Reference(reference) = &*pat.ty {
161                if let syn::Type::Path(path) = &*reference.elem {
162                    return path.path.is_ident("RowArena");
163                }
164            }
165        }
166        false
167    })
168}
169
170/// Arity classification for a function annotated with `#[sqlfunc]`.
171enum Arity {
172    Nullary,
173    Unary { arena: bool },
174    Binary { arena: bool },
175    Variadic { arena: bool, has_self: bool },
176}
177
178/// Checks whether a parameter's type is `Variadic<...>` or `OptionalArg<...>`,
179/// which indicates the function should be treated as variadic regardless of
180/// parameter count.
181fn is_variadic_arg(arg: &syn::FnArg) -> bool {
182    if let syn::FnArg::Typed(pat) = arg {
183        if let syn::Type::Path(path) = &*pat.ty {
184            if let Some(segment) = path.path.segments.last() {
185                let ident = segment.ident.to_string();
186                return ident == "Variadic" || ident == "OptionalArg";
187            }
188        }
189    }
190    false
191}
192
193/// Determines the arity of a function annotated with `#[sqlfunc]`.
194///
195/// Accounts for `&self` receivers, trailing `&RowArena` parameters, and
196/// parameter types like `Variadic<T>` or `OptionalArg<T>` that indicate
197/// variadic dispatch.
198fn determine_arity(func: &syn::ItemFn) -> Arity {
199    let arena = last_is_arena(func);
200    let has_self = matches!(func.sig.inputs.first(), Some(syn::FnArg::Receiver(_)));
201
202    let mut effective_count = func.sig.inputs.len();
203    if arena {
204        effective_count -= 1;
205    }
206    if has_self {
207        effective_count -= 1;
208    }
209
210    // Check if any effective parameter uses a variadic-typed wrapper.
211    let start = if has_self { 1 } else { 0 };
212    let end = if arena {
213        func.sig.inputs.len() - 1
214    } else {
215        func.sig.inputs.len()
216    };
217    let has_variadic_param = func
218        .sig
219        .inputs
220        .iter()
221        .skip(start)
222        .take(end - start)
223        .any(is_variadic_arg);
224
225    if has_variadic_param || effective_count >= 3 {
226        Arity::Variadic { arena, has_self }
227    } else {
228        match effective_count {
229            0 => Arity::Nullary,
230            1 => Arity::Unary { arena },
231            2 => Arity::Binary { arena },
232            _ => unreachable!(),
233        }
234    }
235}
236
237/// Convert an identifier to a camel-cased identifier.
238/// Checks if a parameter type accepts NULL.
239///
240/// `Option<T>` always accepts NULL. `OptionalArg<T>` delegates to `T`.
241/// `Datum` accepts NULL (it passes through raw values including null).
242/// Everything else (references, concrete types) rejects NULL.
243fn is_nullable_type(ty: &syn::Type) -> bool {
244    if let syn::Type::Path(type_path) = ty {
245        if let Some(last_segment) = type_path.path.segments.last() {
246            let ident = &last_segment.ident;
247            if ident == "Option" || ident == "Datum" {
248                return true;
249            }
250            if ident == "OptionalArg" {
251                // OptionalArg<T> delegates nullability to T.
252                if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
253                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
254                        return is_nullable_type(inner_ty);
255                    }
256                }
257                return false;
258            }
259        }
260    }
261    false
262}
263
264/// Checks if a type is `Variadic<T>`.
265fn is_variadic_type(ty: &syn::Type) -> bool {
266    if let syn::Type::Path(type_path) = ty {
267        if let Some(last_segment) = type_path.path.segments.last() {
268            return last_segment.ident == "Variadic";
269        }
270    }
271    false
272}
273
274/// For a `Variadic<T>` type, checks if `T` accepts NULL.
275fn variadic_element_is_nullable(ty: &syn::Type) -> bool {
276    if let syn::Type::Path(type_path) = ty {
277        if let Some(last_segment) = type_path.path.segments.last() {
278            if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
279                if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
280                    return is_nullable_type(inner_ty);
281                }
282            }
283        }
284    }
285    false
286}
287
288/// Generates per-position nullability checks for non-nullable parameters.
289///
290/// For each parameter that rejects NULL (not `Option`, not `OptionalArg<Option<..>>`),
291/// generates a check that the corresponding input position is nullable. For `Variadic<T>`
292/// with non-nullable `T`, generates a check over all remaining input positions.
293fn non_nullable_position_checks(param_types: &[syn::Type]) -> Vec<TokenStream> {
294    let mut checks = Vec::new();
295    for (i, ty) in param_types.iter().enumerate() {
296        if is_variadic_type(ty) {
297            if !variadic_element_is_nullable(ty) {
298                checks.push(quote! { || input_types.iter().skip(#i).any(|t| t.nullable) });
299            }
300        } else if !is_nullable_type(ty) {
301            checks.push(quote! { || input_types.get(#i).map_or(false, |t| t.nullable) });
302        }
303    }
304    checks
305}
306
307fn camel_case(ident: &Ident) -> Ident {
308    let mut result = String::new();
309    let mut capitalize_next = true;
310    for c in ident.to_string().chars() {
311        if c == '_' {
312            capitalize_next = true;
313        } else if capitalize_next {
314            result.push(c.to_ascii_uppercase());
315            capitalize_next = false;
316        } else {
317            result.push(c);
318        }
319    }
320    Ident::new(&result, ident.span())
321}
322
323/// Determines the argument type of the nth argument of the function.
324///
325/// Adds a lifetime `'a` to the argument type if it is a reference type.
326///
327/// Panics if the function has fewer than `nth` arguments. Returns an error if
328/// the parameter is a `self` receiver.
329fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
330    match &arg.sig.inputs[nth] {
331        syn::FnArg::Typed(pat) => {
332            // Patch lifetimes to be 'a if reference
333            if let syn::Type::Reference(r) = &*pat.ty {
334                if r.lifetime.is_none() {
335                    let ty = syn::Type::Reference(syn::TypeReference {
336                        lifetime: Some(Lifetime::new("'a", r.span())),
337                        ..r.clone()
338                    });
339                    return Ok(ty);
340                }
341            }
342            Ok((*pat.ty).clone())
343        }
344        _ => Err(syn::Error::new(
345            arg.sig.inputs[nth].span(),
346            "Unsupported argument type",
347        )),
348    }
349}
350
351/// Recursively patches lifetimes in a type, adding `'a` to references without a lifetime
352/// and recursing into generic arguments and tuples.
353fn patch_lifetimes(ty: &syn::Type) -> syn::Type {
354    match ty {
355        syn::Type::Reference(r) => {
356            let elem = Box::new(patch_lifetimes(&r.elem));
357            if r.lifetime.is_none() {
358                syn::Type::Reference(syn::TypeReference {
359                    lifetime: Some(Lifetime::new("'a", r.span())),
360                    elem,
361                    ..r.clone()
362                })
363            } else {
364                syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
365            }
366        }
367        syn::Type::Tuple(t) => {
368            let elems = t.elems.iter().map(patch_lifetimes).collect();
369            syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
370        }
371        syn::Type::Path(p) => {
372            let mut p = p.clone();
373            for segment in &mut p.path.segments {
374                if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
375                    for arg in &mut args.args {
376                        if let syn::GenericArgument::Type(ty) = arg {
377                            *ty = patch_lifetimes(ty);
378                        }
379                    }
380                }
381            }
382            syn::Type::Path(p)
383        }
384        _ => ty.clone(),
385    }
386}
387
388/// Determine the output type for a function. Returns an error if the function
389/// does not return a value.
390fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
391    match &arg.sig.output {
392        syn::ReturnType::Type(_, ty) => Ok(&*ty),
393        syn::ReturnType::Default => Err(syn::Error::new(
394            arg.sig.output.span(),
395            "Function needs to return a value",
396        )),
397    }
398}
399
400/// Produce a `EagerUnaryFunc` implementation.
401fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
402    let fn_name = &func.sig.ident;
403    let struct_name = camel_case(&func.sig.ident);
404    let input_ty = arg_type(func, 0)?;
405    let output_ty = output_type(func)?;
406    let Modifiers {
407        is_monotone,
408        sqlname,
409        preserves_uniqueness,
410        inverse,
411        is_infix_op,
412        output_type,
413        output_type_expr,
414        negate,
415        could_error,
416        propagates_nulls,
417        introduces_nulls,
418        is_associative,
419        test: _,
420    } = modifiers;
421
422    if is_infix_op.is_some() {
423        return Err(darling::Error::unknown_field(
424            "is_infix_op not supported for unary functions",
425        ));
426    }
427    if output_type.is_some() && output_type_expr.is_some() {
428        return Err(darling::Error::unknown_field(
429            "output_type and output_type_expr cannot be used together",
430        ));
431    }
432    if output_type_expr.is_some() && introduces_nulls.is_none() {
433        return Err(darling::Error::unknown_field(
434            "output_type_expr requires introduces_nulls",
435        ));
436    }
437    if negate.is_some() {
438        return Err(darling::Error::unknown_field(
439            "negate not supported for unary functions",
440        ));
441    }
442    if propagates_nulls.is_some() {
443        return Err(darling::Error::unknown_field(
444            "propagates_nulls not supported for unary functions",
445        ));
446    }
447    if is_associative.is_some() {
448        return Err(darling::Error::unknown_field(
449            "is_associative not supported for unary functions",
450        ));
451    }
452
453    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
454        quote! {
455            fn preserves_uniqueness(&self) -> bool {
456                #preserves_uniqueness
457            }
458        }
459    });
460
461    let inverse_fn = inverse.as_ref().map(|inverse| {
462        quote! {
463            fn inverse(&self) -> Option<crate::UnaryFunc> {
464                #inverse
465            }
466        }
467    });
468
469    let is_monotone_fn = is_monotone.map(|is_monotone| {
470        quote! {
471            fn is_monotone(&self) -> bool {
472                #is_monotone
473            }
474        }
475    });
476
477    let name = sqlname
478        .as_ref()
479        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
480
481    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
482        let introduces_nulls_fn = quote! {
483            fn introduces_nulls(&self) -> bool {
484                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
485            }
486        };
487        let output_type = quote! { <#output_type>::as_column_type() };
488        (output_type, Some(introduces_nulls_fn))
489    } else {
490        (quote! { Self::Output::as_column_type() }, None)
491    };
492
493    if let Some(output_type_expr) = output_type_expr {
494        output_type = quote! { #output_type_expr };
495    }
496
497    if let Some(introduces_nulls) = introduces_nulls {
498        introduces_nulls_fn = Some(quote! {
499            fn introduces_nulls(&self) -> bool {
500                #introduces_nulls
501            }
502        });
503    }
504
505    let could_error_fn = could_error.map(|could_error| {
506        quote! {
507            fn could_error(&self) -> bool {
508                #could_error
509            }
510        }
511    });
512
513    let result = quote! {
514        #[derive(
515            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
516            Debug, Eq, PartialEq, serde::Serialize,
517            serde::Deserialize, Hash, mz_lowertest::MzReflect,
518        )]
519        pub struct #struct_name;
520
521        impl crate::func::EagerUnaryFunc for #struct_name {
522            type Input<'a> = #input_ty;
523            type Output<'a> = #output_ty;
524
525            fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
526                #fn_name(a)
527            }
528
529            fn output_sql_type(
530                &self,
531                input_type: mz_repr::SqlColumnType
532            ) -> mz_repr::SqlColumnType {
533                use mz_repr::AsColumnType;
534                let output = #output_type;
535                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
536                let nullable = output.nullable;
537                // The output is nullable if it is nullable by itself or the input is nullable
538                // and this function propagates nulls
539                output.nullable(nullable || (propagates_nulls && input_type.nullable))
540            }
541
542            #could_error_fn
543            #introduces_nulls_fn
544            #inverse_fn
545            #is_monotone_fn
546            #preserves_uniqueness_fn
547        }
548
549        impl std::fmt::Display for #struct_name {
550            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
551                f.write_str(#name)
552            }
553        }
554
555        #func
556    };
557    Ok(result)
558}
559
560/// Produce a `EagerBinaryFunc` implementation.
561fn binary_func(
562    func: &syn::ItemFn,
563    modifiers: Modifiers,
564    arena: bool,
565) -> darling::Result<TokenStream> {
566    let fn_name = &func.sig.ident;
567    let struct_name = camel_case(&func.sig.ident);
568    let input1_ty = arg_type(func, 0)?;
569    let input2_ty = arg_type(func, 1)?;
570    let output_ty = output_type(func)?;
571
572    let Modifiers {
573        is_monotone,
574        sqlname,
575        preserves_uniqueness,
576        inverse,
577        is_infix_op,
578        output_type,
579        output_type_expr,
580        negate,
581        could_error,
582        propagates_nulls,
583        introduces_nulls,
584        is_associative,
585        test: _,
586    } = modifiers;
587
588    if preserves_uniqueness.is_some() {
589        return Err(darling::Error::unknown_field(
590            "preserves_uniqueness not supported for binary functions",
591        ));
592    }
593    if inverse.is_some() {
594        return Err(darling::Error::unknown_field(
595            "inverse not supported for binary functions",
596        ));
597    }
598    if output_type.is_some() && output_type_expr.is_some() {
599        return Err(darling::Error::unknown_field(
600            "output_type and output_type_expr cannot be used together",
601        ));
602    }
603    if output_type_expr.is_some() && introduces_nulls.is_none() {
604        return Err(darling::Error::unknown_field(
605            "output_type_expr requires introduces_nulls",
606        ));
607    }
608    if is_associative.is_some() {
609        return Err(darling::Error::unknown_field(
610            "is_associative not supported for binary functions",
611        ));
612    }
613
614    let negate_fn = negate.map(|negate| {
615        quote! {
616            fn negate(&self) -> Option<crate::BinaryFunc> {
617                #negate
618            }
619        }
620    });
621
622    let is_monotone_fn = is_monotone.map(|is_monotone| {
623        quote! {
624            fn is_monotone(&self) -> (bool, bool) {
625                #is_monotone
626            }
627        }
628    });
629
630    let name = sqlname
631        .as_ref()
632        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
633
634    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
635        let introduces_nulls_fn = quote! {
636            fn introduces_nulls(&self) -> bool {
637                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
638            }
639        };
640        let output_type = quote! { <#output_type>::as_column_type() };
641        (output_type, Some(introduces_nulls_fn))
642    } else {
643        (quote! { Self::Output::as_column_type() }, None)
644    };
645
646    if let Some(output_type_expr) = output_type_expr {
647        output_type = quote! { #output_type_expr };
648    }
649
650    if let Some(introduces_nulls) = introduces_nulls {
651        introduces_nulls_fn = Some(quote! {
652            fn introduces_nulls(&self) -> bool {
653                #introduces_nulls
654            }
655        });
656    }
657
658    let arena = if arena {
659        quote! { , temp_storage }
660    } else {
661        quote! {}
662    };
663
664    let could_error_fn = could_error.map(|could_error| {
665        quote! {
666            fn could_error(&self) -> bool {
667                #could_error
668            }
669        }
670    });
671
672    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
673        quote! {
674            fn is_infix_op(&self) -> bool {
675                #is_infix_op
676            }
677        }
678    });
679
680    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
681        quote! {
682            fn propagates_nulls(&self) -> bool {
683                #propagates_nulls
684            }
685        }
686    });
687
688    // Per-position checks: for each non-nullable parameter, check if
689    // the corresponding input column is nullable.
690    let binary_non_nullable_checks =
691        non_nullable_position_checks(&[input1_ty.clone(), input2_ty.clone()]);
692
693    let result = quote! {
694        #[derive(
695            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
696            Debug, Eq, PartialEq, serde::Serialize,
697            serde::Deserialize, Hash, mz_lowertest::MzReflect,
698        )]
699        pub struct #struct_name;
700
701        impl crate::func::binary::EagerBinaryFunc for #struct_name {
702            type Input<'a> = (#input1_ty, #input2_ty);
703            type Output<'a> = #output_ty;
704
705            fn call<'a>(
706                &self,
707                (a, b): Self::Input<'a>,
708                temp_storage: &'a mz_repr::RowArena
709            ) -> Self::Output<'a> {
710                #fn_name(a, b #arena)
711            }
712
713            fn output_sql_type(
714                &self,
715                input_types: &[mz_repr::SqlColumnType],
716            ) -> mz_repr::SqlColumnType {
717                use mz_repr::AsColumnType;
718                let output = #output_type;
719                let propagates_nulls =
720                    crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
721                let nullable = output.nullable;
722                // The output is nullable if:
723                // 1. The function introduces nulls (output.nullable), or
724                // 2. A non-nullable parameter's input is nullable (will reject
725                //    NULL at runtime via try_from_iter), or
726                // 3. propagates_nulls is true and any input is nullable
727                //    (optimizer short-circuits all-NULL inputs)
728                let non_nullable_input_is_nullable =
729                    false #(#binary_non_nullable_checks)*;
730                let inputs_nullable = input_types.iter().any(|it| it.nullable);
731                let is_null = nullable
732                    || non_nullable_input_is_nullable
733                    || (propagates_nulls && inputs_nullable);
734                output.nullable(is_null)
735            }
736
737            #could_error_fn
738            #introduces_nulls_fn
739            #is_infix_op_fn
740            #is_monotone_fn
741            #negate_fn
742            #propagates_nulls_fn
743        }
744
745        impl std::fmt::Display for #struct_name {
746            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
747                f.write_str(#name)
748            }
749        }
750
751        #func
752
753    };
754    Ok(result)
755}
756
757/// Produce an `EagerVariadicFunc` implementation.
758///
759/// Two modes based on whether the function has a `&self` receiver:
760/// * `&self` present: struct defined externally, generates method impl + trait impl + Display
761/// * No `&self`: generates unit struct + trait impl + Display + preserves original function
762fn variadic_func(
763    func: &syn::ItemFn,
764    modifiers: Modifiers,
765    struct_ty: Option<syn::Path>,
766    arena: bool,
767    has_self: bool,
768) -> darling::Result<TokenStream> {
769    let fn_name = &func.sig.ident;
770    let output_ty = output_type(func)?;
771    let struct_name = struct_ty
772        .as_ref()
773        .and_then(|ty| ty.segments.last())
774        .map_or_else(|| camel_case(fn_name), |seg| seg.ident.clone());
775
776    let Modifiers {
777        is_monotone,
778        sqlname,
779        preserves_uniqueness,
780        inverse,
781        is_infix_op,
782        output_type,
783        output_type_expr,
784        negate,
785        could_error,
786        propagates_nulls,
787        introduces_nulls,
788        is_associative,
789        test: _,
790    } = modifiers;
791
792    // Reject modifiers that don't apply to variadic functions.
793    if preserves_uniqueness.is_some() {
794        return Err(darling::Error::unknown_field(
795            "preserves_uniqueness not supported for variadic functions",
796        ));
797    }
798    if inverse.is_some() {
799        return Err(darling::Error::unknown_field(
800            "inverse not supported for variadic functions",
801        ));
802    }
803    if negate.is_some() {
804        return Err(darling::Error::unknown_field(
805            "negate not supported for variadic functions",
806        ));
807    }
808    if output_type.is_some() && output_type_expr.is_some() {
809        return Err(darling::Error::unknown_field(
810            "output_type and output_type_expr cannot be used together",
811        ));
812    }
813    if output_type_expr.is_some() && introduces_nulls.is_none() {
814        return Err(darling::Error::unknown_field(
815            "output_type_expr requires introduces_nulls",
816        ));
817    }
818
819    // Collect input parameters (skip &self, skip &RowArena).
820    let start = if has_self { 1 } else { 0 };
821    let end = if arena {
822        func.sig.inputs.len() - 1
823    } else {
824        func.sig.inputs.len()
825    };
826    let input_params: Vec<&syn::FnArg> = func
827        .sig
828        .inputs
829        .iter()
830        .skip(start)
831        .take(end - start)
832        .collect();
833
834    if input_params.is_empty() {
835        return Err(darling::Error::custom(
836            "variadic function must have at least one input parameter",
837        ));
838    }
839
840    // Extract parameter names and types.
841    let mut param_names = Vec::new();
842    let mut param_types = Vec::new();
843    for param in &input_params {
844        match param {
845            syn::FnArg::Typed(pat) => {
846                if let syn::Pat::Ident(ident) = &*pat.pat {
847                    param_names.push(ident.ident.clone());
848                } else {
849                    return Err(
850                        darling::Error::custom("unsupported parameter pattern").with_span(&pat.pat)
851                    );
852                }
853                param_types.push(patch_lifetimes(&pat.ty));
854            }
855            _ => {
856                return Err(darling::Error::custom("unexpected self parameter"));
857            }
858        }
859    }
860
861    // Build input type: single param = bare type, multiple = tuple.
862    let input_type: syn::Type = if param_types.len() == 1 {
863        param_types[0].clone()
864    } else {
865        syn::parse_quote! { (#(#param_types),*) }
866    };
867
868    // Build destructure pattern for call.
869    let destructure = if param_names.len() == 1 {
870        let name = &param_names[0];
871        quote! { #name }
872    } else {
873        quote! { (#(#param_names),*) }
874    };
875
876    let arena_arg = if arena {
877        quote! { , temp_storage }
878    } else {
879        quote! {}
880    };
881
882    let call_expr = if has_self {
883        quote! { self.#fn_name(#(#param_names),* #arena_arg) }
884    } else {
885        quote! { #fn_name(#(#param_names),* #arena_arg) }
886    };
887
888    // Build modifier functions.
889    let name = sqlname
890        .as_ref()
891        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
892
893    let (mut output_type_code, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
894        let introduces_nulls_fn = quote! {
895            fn introduces_nulls(&self) -> bool {
896                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
897            }
898        };
899        let output_type_code = quote! { <#output_type>::as_column_type() };
900        (output_type_code, Some(introduces_nulls_fn))
901    } else {
902        (quote! { Self::Output::as_column_type() }, None)
903    };
904
905    if let Some(output_type_expr) = output_type_expr {
906        output_type_code = quote! { #output_type_expr };
907    }
908
909    if let Some(introduces_nulls) = introduces_nulls {
910        introduces_nulls_fn = Some(quote! {
911            fn introduces_nulls(&self) -> bool {
912                #introduces_nulls
913            }
914        });
915    }
916
917    let could_error_fn = could_error.map(|could_error| {
918        quote! {
919            fn could_error(&self) -> bool {
920                #could_error
921            }
922        }
923    });
924
925    let is_monotone_fn = is_monotone.map(|is_monotone| {
926        quote! {
927            fn is_monotone(&self) -> bool {
928                #is_monotone
929            }
930        }
931    });
932
933    let is_associative_fn = is_associative.map(|is_associative| {
934        quote! {
935            fn is_associative(&self) -> bool {
936                #is_associative
937            }
938        }
939    });
940
941    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
942        quote! {
943            fn is_infix_op(&self) -> bool {
944                #is_infix_op
945            }
946        }
947    });
948
949    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
950        quote! {
951            fn propagates_nulls(&self) -> bool {
952                #propagates_nulls
953            }
954        }
955    });
956
957    // Per-position checks: for each non-nullable parameter, check if
958    // the corresponding input column is nullable.
959    let non_nullable_checks = non_nullable_position_checks(&param_types);
960
961    let trait_impl = quote! {
962        impl crate::func::variadic::EagerVariadicFunc for #struct_name {
963            type Input<'a> = #input_type;
964            type Output<'a> = #output_ty;
965
966            fn call<'a>(
967                &self,
968                #destructure: Self::Input<'a>,
969                temp_storage: &'a mz_repr::RowArena,
970            ) -> Self::Output<'a> {
971                #call_expr
972            }
973
974            fn output_type(
975                &self,
976                input_types: &[mz_repr::SqlColumnType],
977            ) -> mz_repr::SqlColumnType {
978                use mz_repr::AsColumnType;
979                let output = #output_type_code;
980                let propagates_nulls =
981                    crate::func::variadic::EagerVariadicFunc::propagates_nulls(self);
982                let nullable = output.nullable;
983                // The output is nullable if:
984                // 1. The function introduces nulls (output.nullable), or
985                // 2. A non-nullable parameter's input is nullable (will reject
986                //    NULL at runtime via try_from_iter), or
987                // 3. propagates_nulls is true and any input is nullable
988                //    (optimizer short-circuits all-NULL inputs)
989                let non_nullable_input_is_nullable =
990                    false #(#non_nullable_checks)*;
991                let inputs_nullable = input_types.iter().any(|it| it.nullable);
992                output.nullable(
993                    nullable
994                    || non_nullable_input_is_nullable
995                    || (propagates_nulls && inputs_nullable)
996                )
997            }
998
999            #could_error_fn
1000            #introduces_nulls_fn
1001            #is_infix_op_fn
1002            #is_monotone_fn
1003            #is_associative_fn
1004            #propagates_nulls_fn
1005        }
1006    };
1007
1008    let display_impl = quote! {
1009        impl std::fmt::Display for #struct_name {
1010            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1011                f.write_str(#name)
1012            }
1013        }
1014    };
1015
1016    let result = if has_self {
1017        // External struct: generate method impl + trait impl + Display.
1018        quote! {
1019            impl #struct_name {
1020                #func
1021            }
1022            #trait_impl
1023            #display_impl
1024        }
1025    } else {
1026        // Unit struct: generate struct + trait impl + Display + original function.
1027        quote! {
1028            #[derive(
1029                proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
1030                Debug, Eq, PartialEq, serde::Serialize,
1031                serde::Deserialize, Hash, mz_lowertest::MzReflect,
1032            )]
1033            pub struct #struct_name;
1034
1035            #trait_impl
1036            #display_impl
1037
1038            #func
1039        }
1040    };
1041
1042    Ok(result)
1043}