Skip to main content

mz_expr_derive_impl/
sqlfunc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Modifiers passed as key-value pairs to the `#[sqlfunc]` macro.
///
/// Parsed from the attribute arguments via `darling::FromMeta::from_list`; any modifier
/// that is not written in the attribute stays `None`. Which modifiers are meaningful
/// depends on the arity of the annotated function, as noted on each field.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// An optional expression that evaluates to a boolean indicating whether the function is
    /// monotone with respect to its arguments. Defined for unary and binary functions.
    is_monotone: Option<Expr>,
    /// The SQL name for the function, either a literal or a macro invocation.
    /// Applies to all functions.
    sqlname: Option<SqlName>,
    /// Whether the function preserves uniqueness. Applies to unary functions.
    preserves_uniqueness: Option<Expr>,
    /// The inverse of the function, if it exists. Applies to unary functions.
    inverse: Option<Expr>,
    /// The negated function, if it exists. Applies to binary functions.
    negate: Option<Expr>,
    /// Whether the function is an infix operator. Applies to binary functions, and needs to
    /// be specified.
    is_infix_op: Option<Expr>,
    /// The output type of the function, if it cannot be inferred. Applies to all functions.
    output_type: Option<syn::Path>,
    /// The output type of the function as an expression. Applies to binary and variadic functions.
    output_type_expr: Option<Expr>,
    /// Optional expression evaluating to a boolean indicating whether the function could error.
    /// Applies to all functions.
    could_error: Option<Expr>,
    /// Whether the function propagates nulls. Applies to binary and variadic functions.
    propagates_nulls: Option<Expr>,
    /// Whether the function introduces nulls. Applies to all functions.
    introduces_nulls: Option<Expr>,
    /// Whether the function is associative. Applies to variadic functions.
    is_associative: Option<Expr>,
    /// Whether the function is a noop cast. Applies to unary functions.
    is_eliminable_cast: Option<Expr>,
    /// Whether to generate a snapshot test for the function. Treated as `false` when unset.
    test: Option<bool>,
}
51
/// A name for the SQL function. It can be either a literal or a macro, thus we
/// can't use `String` or `syn::Expr` directly.
///
/// Whichever form was written is emitted back verbatim by the `ToTokens` impl.
#[derive(Debug)]
enum SqlName {
    /// A literal, normally a string. Any `syn::Lit` is accepted by the `FromMeta` impl.
    Literal(syn::Lit),
    /// A macro expression, kept unexpanded.
    Macro(syn::ExprMacro),
}
61
62impl quote::ToTokens for SqlName {
63    fn to_tokens(&self, tokens: &mut TokenStream) {
64        let name = match self {
65            SqlName::Literal(lit) => quote! { #lit },
66            SqlName::Macro(mac) => quote! { #mac },
67        };
68        tokens.extend(name);
69    }
70}
71
72impl darling::FromMeta for SqlName {
73    fn from_value(value: &Lit) -> darling::Result<Self> {
74        Ok(Self::Literal(value.clone()))
75    }
76    fn from_expr(expr: &Expr) -> darling::Result<Self> {
77        match expr {
78            Expr::Lit(lit) => Self::from_value(&lit.lit),
79            Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
80            // Syn sometimes inserts groups, see `FromMeta::from_expr` for
81            // details.
82            Expr::Group(mac) => Self::from_expr(&mac.expr),
83            _ => Err(darling::Error::unexpected_expr_type(expr)),
84        }
85    }
86}
87
88/// Implementation for the `#[sqlfunc]` macro. The first parameter is the attribute
89/// arguments, the second is the function body. The third parameter indicates
90/// whether to include the test function in the output.
91///
92/// The feature `test` must be enabled to include the test function.
93pub fn sqlfunc(
94    attr: TokenStream,
95    item: TokenStream,
96    include_test: bool,
97) -> darling::Result<TokenStream> {
98    let mut attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
99
100    // Check if the first attribute arg is a bare Path (struct name for variadic).
101    let struct_ty = match attr_args.first() {
102        Some(darling::ast::NestedMeta::Meta(syn::Meta::Path(_))) => {
103            let darling::ast::NestedMeta::Meta(syn::Meta::Path(path)) = attr_args.remove(0) else {
104                unreachable!()
105            };
106            Some(path)
107        }
108        _ => None,
109    };
110
111    let modifiers = Modifiers::from_list(&attr_args).unwrap();
112    let generate_tests = modifiers.test.unwrap_or(false);
113    let func = syn::parse2::<syn::ItemFn>(item.clone())?;
114
115    let tokens = match determine_arity(&func) {
116        Arity::Nullary => Err(darling::Error::custom("Nullary functions not supported")),
117        Arity::Unary { arena: false } => unary_func(&func, modifiers),
118        Arity::Unary { arena: true } => Err(darling::Error::custom(
119            "Unary functions do not yet support RowArena.",
120        )),
121        Arity::Binary { arena } => binary_func(&func, modifiers, arena),
122        Arity::Variadic { arena, has_self } => {
123            variadic_func(&func, modifiers, struct_ty, arena, has_self)
124        }
125    }?;
126
127    let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
128
129    Ok(quote! {
130        #tokens
131        #test
132    })
133}
134
/// Emits a snapshot test named `test_<name>` for the annotated function.
///
/// The attribute and item token streams are stringified and passed to
/// `mz_expr_derive_impl::test_sqlfunc_str`; its output is checked with
/// `insta::assert_snapshot!` under the function's name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let test_fn = Ident::new(&format!("test_{name}"), name.span());
    let snapshot_name = name.to_string();
    let attr_src = attr.to_string();
    let item_src = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri
        #[mz_ore::test]
        fn #test_fn() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
152
153#[cfg(not(any(feature = "test", test)))]
154fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
155    quote! {}
156}
157
158/// Checks if the last parameter of the function is a `&RowArena`.
159fn last_is_arena(func: &syn::ItemFn) -> bool {
160    func.sig.inputs.last().map_or(false, |last| {
161        if let syn::FnArg::Typed(pat) = last {
162            if let syn::Type::Reference(reference) = &*pat.ty {
163                if let syn::Type::Path(path) = &*reference.elem {
164                    return path.path.is_ident("RowArena");
165                }
166            }
167        }
168        false
169    })
170}
171
/// Arity classification for a function annotated with `#[sqlfunc]`.
///
/// Parameter counts exclude a leading `self` receiver and a trailing `&RowArena`
/// parameter (see `determine_arity`).
enum Arity {
    /// No effective parameters; `sqlfunc` rejects these.
    Nullary,
    /// One effective parameter. `arena` records a trailing `&RowArena` parameter
    /// (`sqlfunc` currently rejects the `arena: true` case).
    Unary { arena: bool },
    /// Two effective parameters. `arena` records a trailing `&RowArena` parameter.
    Binary { arena: bool },
    /// Three or more effective parameters, or any `Variadic<..>`/`OptionalArg<..>`
    /// parameter. `has_self` records a leading `self` receiver.
    Variadic { arena: bool, has_self: bool },
}
179
180/// Checks whether a parameter's type is `Variadic<...>` or `OptionalArg<...>`,
181/// which indicates the function should be treated as variadic regardless of
182/// parameter count.
183fn is_variadic_arg(arg: &syn::FnArg) -> bool {
184    if let syn::FnArg::Typed(pat) = arg {
185        if let syn::Type::Path(path) = &*pat.ty {
186            if let Some(segment) = path.path.segments.last() {
187                let ident = segment.ident.to_string();
188                return ident == "Variadic" || ident == "OptionalArg";
189            }
190        }
191    }
192    false
193}
194
195/// Determines the arity of a function annotated with `#[sqlfunc]`.
196///
197/// Accounts for `&self` receivers, trailing `&RowArena` parameters, and
198/// parameter types like `Variadic<T>` or `OptionalArg<T>` that indicate
199/// variadic dispatch.
200fn determine_arity(func: &syn::ItemFn) -> Arity {
201    let arena = last_is_arena(func);
202    let has_self = matches!(func.sig.inputs.first(), Some(syn::FnArg::Receiver(_)));
203
204    let mut effective_count = func.sig.inputs.len();
205    if arena {
206        effective_count -= 1;
207    }
208    if has_self {
209        effective_count -= 1;
210    }
211
212    // Check if any effective parameter uses a variadic-typed wrapper.
213    let start = if has_self { 1 } else { 0 };
214    let end = if arena {
215        func.sig.inputs.len() - 1
216    } else {
217        func.sig.inputs.len()
218    };
219    let has_variadic_param = func
220        .sig
221        .inputs
222        .iter()
223        .skip(start)
224        .take(end - start)
225        .any(is_variadic_arg);
226
227    if has_variadic_param || effective_count >= 3 {
228        Arity::Variadic { arena, has_self }
229    } else {
230        match effective_count {
231            0 => Arity::Nullary,
232            1 => Arity::Unary { arena },
233            2 => Arity::Binary { arena },
234            _ => unreachable!(),
235        }
236    }
237}
238
239/// Convert an identifier to a camel-cased identifier.
240/// Checks if a parameter type accepts NULL.
241///
242/// `Option<T>` always accepts NULL. `OptionalArg<T>` delegates to `T`.
243/// `Datum` accepts NULL (it passes through raw values including null).
244/// Everything else (references, concrete types) rejects NULL.
245fn is_nullable_type(ty: &syn::Type) -> bool {
246    if let syn::Type::Path(type_path) = ty {
247        if let Some(last_segment) = type_path.path.segments.last() {
248            let ident = &last_segment.ident;
249            if ident == "Option" || ident == "Datum" {
250                return true;
251            }
252            if ident == "OptionalArg" {
253                // OptionalArg<T> delegates nullability to T.
254                if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
255                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
256                        return is_nullable_type(inner_ty);
257                    }
258                }
259                return false;
260            }
261        }
262    }
263    false
264}
265
266/// Checks if a type is `Variadic<T>`.
267fn is_variadic_type(ty: &syn::Type) -> bool {
268    if let syn::Type::Path(type_path) = ty {
269        if let Some(last_segment) = type_path.path.segments.last() {
270            return last_segment.ident == "Variadic";
271        }
272    }
273    false
274}
275
276/// For a `Variadic<T>` type, checks if `T` accepts NULL.
277fn variadic_element_is_nullable(ty: &syn::Type) -> bool {
278    if let syn::Type::Path(type_path) = ty {
279        if let Some(last_segment) = type_path.path.segments.last() {
280            if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
281                if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
282                    return is_nullable_type(inner_ty);
283                }
284            }
285        }
286    }
287    false
288}
289
290/// Generates per-position nullability checks for non-nullable parameters.
291///
292/// For each parameter that rejects NULL (not `Option`, not `OptionalArg<Option<..>>`),
293/// generates a check that the corresponding input position is nullable. For `Variadic<T>`
294/// with non-nullable `T`, generates a check over all remaining input positions.
295fn non_nullable_position_checks(param_types: &[syn::Type]) -> Vec<TokenStream> {
296    let mut checks = Vec::new();
297    for (i, ty) in param_types.iter().enumerate() {
298        if is_variadic_type(ty) {
299            if !variadic_element_is_nullable(ty) {
300                checks.push(quote! { || input_types.iter().skip(#i).any(|t| t.nullable) });
301            }
302        } else if !is_nullable_type(ty) {
303            checks.push(quote! { || input_types.get(#i).map_or(false, |t| t.nullable) });
304        }
305    }
306    checks
307}
308
309fn camel_case(ident: &Ident) -> Ident {
310    let mut result = String::new();
311    let mut capitalize_next = true;
312    for c in ident.to_string().chars() {
313        if c == '_' {
314            capitalize_next = true;
315        } else if capitalize_next {
316            result.push(c.to_ascii_uppercase());
317            capitalize_next = false;
318        } else {
319            result.push(c);
320        }
321    }
322    Ident::new(&result, ident.span())
323}
324
325/// Extracts generic type parameters from a function signature.
326/// Returns an empty Vec if there are no type parameters.
327fn find_generic_type_params(func: &syn::ItemFn) -> Vec<Ident> {
328    func.sig
329        .generics
330        .params
331        .iter()
332        .filter_map(|p| {
333            if let syn::GenericParam::Type(tp) = p {
334                Some(tp.ident.clone())
335            } else {
336                None
337            }
338        })
339        .collect()
340}
341
/// How a generic type parameter `T` appears in a type (computed by `classify_generic_usage`).
#[derive(Debug, Clone)]
enum GenericUsage {
    /// `T` does not appear in this type.
    Absent,
    /// `T` appears bare (possibly wrapped in `Option`, `Result`, or `ExcludeNull`, or
    /// behind a reference). Also returned when a tuple holds two different containers
    /// of `T`, since no single container can be chosen.
    Bare,
    /// `T` appears inside a container type (e.g. `DatumList<'a, T>`, `Array<'a, T>`).
    /// The stored `syn::TypePath` is the container with `T` erased to `Datum<'a>`.
    /// Equality between two of these compares only the containers' path idents.
    InContainer(syn::TypePath),
}
353
354impl PartialEq for GenericUsage {
355    fn eq(&self, other: &Self) -> bool {
356        match (self, other) {
357            (GenericUsage::Absent, GenericUsage::Absent) => true,
358            (GenericUsage::Bare, GenericUsage::Bare) => true,
359            (GenericUsage::InContainer(a), GenericUsage::InContainer(b)) => {
360                container_idents_match(a, b)
361            }
362            _ => false,
363        }
364    }
365}
366
367impl Eq for GenericUsage {}
368
369/// Compare two container type paths by their ident segments (ignoring lifetimes
370/// and generic arguments). Two containers are "same" if their path idents match.
371///
372/// This is safe because after erasure all container types have the same generic
373/// arity (lifetimes + `Datum<'a>`), so ident equality implies structural equality.
374fn container_idents_match(a: &syn::TypePath, b: &syn::TypePath) -> bool {
375    let a_idents: Vec<_> = a.path.segments.iter().map(|s| &s.ident).collect();
376    let b_idents: Vec<_> = b.path.segments.iter().map(|s| &s.ident).collect();
377    a_idents == b_idents
378}
379
380/// Classifies how a generic type parameter appears in a type.
381///
382/// Strips `Option<...>`, `Result<..., E>`, and `ExcludeNull<...>` wrappers before
383/// inspecting the inner type. Any generic type wrapping `T` that isn't `Option`,
384/// `Result`, or `ExcludeNull` is treated as a container. If the container doesn't
385/// implement `SqlContainerType`, the generated code won't compile (a clear error).
386fn classify_generic_usage(ty: &syn::Type, generic_name: &Ident) -> GenericUsage {
387    match ty {
388        syn::Type::Path(type_path) => {
389            if type_path.path.is_ident(generic_name) {
390                return GenericUsage::Bare;
391            }
392            if let Some(last) = type_path.path.segments.last() {
393                let ident_str = last.ident.to_string();
394                // Unwrap Option, Result, ExcludeNull wrappers
395                if ident_str == "Option" || ident_str == "Result" || ident_str == "ExcludeNull" {
396                    if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
397                        if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
398                            return classify_generic_usage(inner, generic_name);
399                        }
400                    }
401                }
402                // Check if any angle-bracketed arg contains the generic param.
403                // If so, treat this type as a container.
404                if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
405                    let has_generic_arg = args.args.iter().any(|arg| {
406                        if let syn::GenericArgument::Type(inner) = arg {
407                            type_contains_ident(inner, generic_name)
408                        } else {
409                            false
410                        }
411                    });
412                    if has_generic_arg {
413                        // Build the erased container type path (T → Datum<'a>).
414                        let erased = erase_generic_param(ty, generic_name);
415                        if let syn::Type::Path(erased_path) = erased {
416                            return GenericUsage::InContainer(erased_path);
417                        }
418                    }
419                    // Recurse into args for nested containers
420                    // (e.g., Option<DatumList<'a, T>> was already handled by
421                    // the Option unwrapping above, but handle other nestings)
422                    for arg in &args.args {
423                        if let syn::GenericArgument::Type(inner) = arg {
424                            let inner_usage = classify_generic_usage(inner, generic_name);
425                            if inner_usage != GenericUsage::Absent {
426                                return inner_usage;
427                            }
428                        }
429                    }
430                }
431            }
432            GenericUsage::Absent
433        }
434        syn::Type::Reference(r) => classify_generic_usage(&r.elem, generic_name),
435        syn::Type::Tuple(t) => {
436            // Prefer container usages over bare. For example, `(T, DatumList<'_, T>)`
437            // should classify as `InDatumList`, not `Bare`.
438            let mut best = GenericUsage::Absent;
439            for elem in &t.elems {
440                let usage = classify_generic_usage(elem, generic_name);
441                match (&best, &usage) {
442                    (GenericUsage::Absent, _) => best = usage,
443                    (GenericUsage::Bare, u) if *u != GenericUsage::Absent => best = usage.clone(),
444                    _ => {
445                        if usage != GenericUsage::Absent && usage != best {
446                            // Conflicting container usages — cannot resolve.
447                            return GenericUsage::Bare;
448                        }
449                    }
450                }
451            }
452            best
453        }
454        _ => GenericUsage::Absent,
455    }
456}
457
458/// Returns whether a type syntactically contains an identifier.
459fn type_contains_ident(ty: &syn::Type, ident: &Ident) -> bool {
460    match ty {
461        syn::Type::Path(type_path) => {
462            if type_path.path.is_ident(ident) {
463                return true;
464            }
465            if let Some(last) = type_path.path.segments.last() {
466                if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
467                    return args.args.iter().any(|arg| {
468                        if let syn::GenericArgument::Type(inner) = arg {
469                            type_contains_ident(inner, ident)
470                        } else {
471                            false
472                        }
473                    });
474                }
475            }
476            false
477        }
478        syn::Type::Reference(r) => type_contains_ident(&r.elem, ident),
479        syn::Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)),
480        _ => false,
481    }
482}
483
484/// Returns whether the outermost wrapper of a type is `Option`.
485fn is_option_wrapped(ty: &syn::Type) -> bool {
486    if let syn::Type::Path(type_path) = ty {
487        if let Some(last) = type_path.path.segments.last() {
488            return last.ident == "Option";
489        }
490    }
491    false
492}
493
494/// Derives an `output_type_expr` TokenStream from the structural relationship
495/// between input types and the output type, based on where generic parameters appear.
496///
497/// Finds the first generic parameter that appears in the output type, then looks for
498/// the first input parameter containing that generic in a container type to determine
499/// the unwrap strategy.
500///
501/// `is_unary` controls whether the generated expression uses `input_type`
502/// (singular, for unary functions) or `input_types[i]` (indexed, for binary/variadic).
503///
504/// Returns `None` if no generic parameter appears in the output type.
505fn derive_output_type_for_generics(
506    input_types: &[syn::Type],
507    output_ty: &syn::Type,
508    generic_names: &[Ident],
509    is_unary: bool,
510) -> darling::Result<Option<TokenStream>> {
511    // Find the first generic param that appears in the output.
512    let generic_name = match generic_names
513        .iter()
514        .find(|gn| classify_generic_usage(output_ty, gn) != GenericUsage::Absent)
515    {
516        Some(gn) => gn,
517        None => return Ok(None),
518    };
519    derive_output_type_for_generic(input_types, output_ty, generic_name, is_unary)
520}
521
522/// Derives an `output_type_expr` for a single generic parameter.
523///
524/// Uses `SqlContainerType` trait calls instead of matching on specific container
525/// type names. The generated code calls `<Container as SqlContainerType>::unwrap_element_type()`
526/// and `wrap_element_type()`, letting Rust's type system resolve the correct behavior.
527fn derive_output_type_for_generic(
528    input_types: &[syn::Type],
529    output_ty: &syn::Type,
530    generic_name: &Ident,
531    is_unary: bool,
532) -> darling::Result<Option<TokenStream>> {
533    let output_usage = classify_generic_usage(output_ty, generic_name);
534    if output_usage == GenericUsage::Absent {
535        return Ok(None);
536    }
537
538    let nullable = is_option_wrapped(output_ty);
539
540    // Find the first input parameter that has T in a container.
541    // Prefer container inputs over bare inputs.
542    let mut container_input: Option<(usize, GenericUsage)> = None;
543    for (i, ty) in input_types.iter().enumerate() {
544        let usage = classify_generic_usage(ty, generic_name);
545        match &usage {
546            GenericUsage::InContainer(_) => {
547                container_input = Some((i, usage));
548                break;
549            }
550            GenericUsage::Bare => {
551                // Bare T in input — not a container, keep looking for a container.
552                if container_input.is_none() {
553                    container_input = Some((i, usage));
554                }
555            }
556            GenericUsage::Absent => {}
557        }
558    }
559
560    let (pos, source_usage) = container_input.ok_or_else(|| {
561        darling::Error::custom(
562            "generic parameter T appears in the output type but not in any input type",
563        )
564    })?;
565
566    // Generate the base expression to access the input type.
567    let input_access = if is_unary {
568        quote! { input_type }
569    } else {
570        let pos_lit = syn::Index::from(pos);
571        quote! { input_types[#pos_lit] }
572    };
573
574    // For multi-input functions, generate soft assertions that all inputs
575    // carrying T agree on the SQL element type. This catches bugs in the
576    // planner's overload resolution or cast insertion.
577    let consistency_checks = if !is_unary {
578        let mut checks = Vec::new();
579        for (i, ty) in input_types.iter().enumerate() {
580            if i == pos {
581                continue;
582            }
583            let usage = classify_generic_usage(ty, generic_name);
584            if usage == GenericUsage::Absent {
585                continue;
586            }
587            let primary_elem = element_type_expr(&input_access, &source_usage);
588            let i_lit = syn::Index::from(i);
589            let other_access = quote! { input_types[#i_lit] };
590            let other_elem = element_type_expr(&other_access, &usage);
591            let generic_str = generic_name.to_string();
592            checks.push(quote! {
593                mz_ore::soft_assert_or_log!(
594                    #primary_elem.base_eq(#other_elem),
595                    "auto-derived sqlfunc output type inference found inconsistent \
596                     SQL types for generic {} across inputs: {:?} vs {:?}; \
597                     this indicates a bug in polymorphic coercion, builtin \
598                     declaration, or sqlfunc inference",
599                    #generic_str,
600                    #primary_elem,
601                    #other_elem,
602                );
603            });
604        }
605        quote! { #(#checks)* }
606    } else {
607        quote! {}
608    };
609
610    // Now generate the output_type_expr based on the combination of
611    // source container and output usage.
612    let expr = match (&output_usage, &source_usage) {
613        // Output is bare T, source is a container → unwrap element type via trait.
614        (GenericUsage::Bare, GenericUsage::InContainer(in_container)) => {
615            let in_c = elide_lifetimes(in_container);
616            quote! {
617                {
618                    #consistency_checks
619                    <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
620                        &#input_access.scalar_type
621                    ).clone().nullable(#nullable)
622                }
623            }
624        }
625        // Output is bare T, source is bare T → forward input type directly.
626        (GenericUsage::Bare, GenericUsage::Bare) => {
627            quote! {
628                {
629                    #consistency_checks
630                    #input_access.scalar_type.clone().nullable(#nullable)
631                }
632            }
633        }
634        // Output is a container, source is a container (same or different) →
635        // unwrap from input container, wrap into output container via traits.
636        (GenericUsage::InContainer(out_container), GenericUsage::InContainer(in_container)) => {
637            let out_c = elide_lifetimes(out_container);
638            let in_c = elide_lifetimes(in_container);
639            quote! {
640                {
641                    #consistency_checks
642                    <#out_c as mz_repr::SqlContainerType>::wrap_element_type(
643                        <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
644                            &#input_access.scalar_type
645                        ).clone()
646                    ).nullable(#nullable)
647                }
648            }
649        }
650        // Other cases — user must provide explicit output_type_expr.
651        _ => {
652            return Err(darling::Error::custom(format!(
653                "cannot auto-derive output_type_expr: output uses T as {:?} but \
654                 the first T-containing input uses T as {:?}",
655                output_usage, source_usage
656            )));
657        }
658    };
659
660    Ok(Some(expr))
661}
662
663/// Generates a token stream that extracts the T-level SQL type from an input
664/// access expression, based on how T is used in that input.
665fn element_type_expr(input_access: &TokenStream, usage: &GenericUsage) -> TokenStream {
666    match usage {
667        GenericUsage::Bare => {
668            quote! { &#input_access.scalar_type }
669        }
670        GenericUsage::InContainer(container) => {
671            let c = elide_lifetimes(container);
672            quote! {
673                <#c as mz_repr::SqlContainerType>::unwrap_element_type(
674                    &#input_access.scalar_type
675                )
676            }
677        }
678        GenericUsage::Absent => unreachable!("element_type_expr called with Absent usage"),
679    }
680}
681
682/// Replaces all lifetime parameters in a `syn::TypePath` with `'_`.
683///
684/// Used for container type paths in turbofish position (e.g.
685/// `<DatumList<'_, Datum<'_>> as SqlContainerType>::...`).
686/// The `output_sql_type` method's `&self` provides an implicit lifetime
687/// that the compiler can infer through `'_`.
688fn elide_lifetimes(tp: &syn::TypePath) -> syn::TypePath {
689    let mut tp = tp.clone();
690    for segment in &mut tp.path.segments {
691        if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
692            for arg in &mut args.args {
693                match arg {
694                    syn::GenericArgument::Lifetime(lt) => {
695                        *lt = Lifetime::new("'_", lt.span());
696                    }
697                    syn::GenericArgument::Type(ty) => {
698                        elide_lifetimes_in_type(ty);
699                    }
700                    _ => {}
701                }
702            }
703        }
704    }
705    tp
706}
707
708/// Recursively replaces all lifetime parameters in a `syn::Type` with `'_`.
709fn elide_lifetimes_in_type(ty: &mut syn::Type) {
710    match ty {
711        syn::Type::Path(tp) => {
712            *tp = elide_lifetimes(tp);
713        }
714        syn::Type::Reference(r) => {
715            if let Some(lt) = &mut r.lifetime {
716                *lt = Lifetime::new("'_", lt.span());
717            }
718            elide_lifetimes_in_type(&mut r.elem);
719        }
720        syn::Type::Tuple(t) => {
721            for elem in &mut t.elems {
722                elide_lifetimes_in_type(elem);
723            }
724        }
725        _ => {}
726    }
727}
728
729/// Replaces occurrences of a generic type parameter with `Datum<'a>` in a type.
730///
731/// Used to convert types from the user's generic function signature into concrete
732/// types for the generated trait impl's associated types, where `T` is not in scope.
733fn erase_generic_param(ty: &syn::Type, generic_name: &Ident) -> syn::Type {
734    match ty {
735        syn::Type::Path(type_path) => {
736            if type_path.path.is_ident(generic_name) {
737                return syn::parse_quote!(Datum<'a>);
738            }
739            let mut type_path = type_path.clone();
740            for segment in &mut type_path.path.segments {
741                if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
742                    for arg in &mut args.args {
743                        if let syn::GenericArgument::Type(inner) = arg {
744                            *inner = erase_generic_param(inner, generic_name);
745                        }
746                    }
747                }
748            }
749            syn::Type::Path(type_path)
750        }
751        syn::Type::Reference(r) => {
752            let elem = Box::new(erase_generic_param(&r.elem, generic_name));
753            syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
754        }
755        syn::Type::Tuple(t) => {
756            let elems = t
757                .elems
758                .iter()
759                .map(|e| erase_generic_param(e, generic_name))
760                .collect();
761            syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
762        }
763        _ => ty.clone(),
764    }
765}
766
767/// Erases all generic type parameters from a type, replacing each with `Datum<'a>`.
768fn erase_all_generic_params(ty: &syn::Type, generic_names: &[Ident]) -> syn::Type {
769    let mut ty = ty.clone();
770    for gn in generic_names {
771        ty = erase_generic_param(&ty, gn);
772    }
773    ty
774}
775
776/// Determines the argument type of the nth argument of the function.
777///
778/// Adds a lifetime `'a` to the argument type if it is a reference type.
779///
780/// Panics if the function has fewer than `nth` arguments. Returns an error if
781/// the parameter is a `self` receiver.
782fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
783    match &arg.sig.inputs[nth] {
784        syn::FnArg::Typed(pat) => {
785            // Patch lifetimes to be 'a if reference
786            if let syn::Type::Reference(r) = &*pat.ty {
787                if r.lifetime.is_none() {
788                    let ty = syn::Type::Reference(syn::TypeReference {
789                        lifetime: Some(Lifetime::new("'a", r.span())),
790                        ..r.clone()
791                    });
792                    return Ok(ty);
793                }
794            }
795            Ok((*pat.ty).clone())
796        }
797        syn::FnArg::Receiver(_) => Err(syn::Error::new(
798            arg.sig.inputs[nth].span(),
799            "Unsupported argument type",
800        )),
801    }
802}
803
804/// Recursively patches lifetimes in a type, adding `'a` to references without a lifetime
805/// and recursing into generic arguments and tuples.
806fn patch_lifetimes(ty: &syn::Type) -> syn::Type {
807    match ty {
808        syn::Type::Reference(r) => {
809            let elem = Box::new(patch_lifetimes(&r.elem));
810            if r.lifetime.is_none() {
811                syn::Type::Reference(syn::TypeReference {
812                    lifetime: Some(Lifetime::new("'a", r.span())),
813                    elem,
814                    ..r.clone()
815                })
816            } else {
817                syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
818            }
819        }
820        syn::Type::Tuple(t) => {
821            let elems = t.elems.iter().map(patch_lifetimes).collect();
822            syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
823        }
824        syn::Type::Path(p) => {
825            let mut p = p.clone();
826            for segment in &mut p.path.segments {
827                if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
828                    for arg in &mut args.args {
829                        if let syn::GenericArgument::Type(ty) = arg {
830                            *ty = patch_lifetimes(ty);
831                        }
832                    }
833                }
834            }
835            syn::Type::Path(p)
836        }
837        _ => ty.clone(),
838    }
839}
840
841/// Determine the output type for a function. Returns an error if the function
842/// does not return a value.
843fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
844    match &arg.sig.output {
845        syn::ReturnType::Type(_, ty) => Ok(&*ty),
846        syn::ReturnType::Default => Err(syn::Error::new(
847            arg.sig.output.span(),
848            "Function needs to return a value",
849        )),
850    }
851}
852
/// Produce a `EagerUnaryFunc` implementation.
///
/// Expands to:
/// * a unit struct named after the camel-cased function name,
/// * an `impl crate::func::EagerUnaryFunc` whose `call` forwards to the function,
/// * a `Display` impl that writes `sqlname` if given, or else the function's name,
/// * the original function, emitted unchanged.
///
/// Generic type parameters found by `find_generic_type_params` are replaced by
/// `Datum<'a>` in the `Input`/`Output` associated types. Suppose generics are present and
/// neither `output_type` nor `output_type_expr` was supplied. Then an output type
/// expression may be derived by `derive_output_type_for_generics`. When one is derived,
/// `introduces_nulls` (unless given) defaults to `is_option_wrapped(<return type>)`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the first parameter is a `self` receiver;
/// * the function has no return type;
/// * deriving or parsing the generic output type expression fails;
/// * a modifier that does not apply to unary functions is set (`is_infix_op`, `negate`,
///   `propagates_nulls`, `is_associative`);
/// * both `output_type` and `output_type_expr` are set;
/// * `output_type_expr` is set (explicitly or by derivation) without `introduces_nulls`.
///
/// # Panics
///
/// Panics if the function has no parameters, since `arg_type` indexes parameter 0.
fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input_ty_raw = arg_type(func, 0)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // Erase generic type params → Datum<'a> for use in the trait impl's associated types.
    let input_ty = erase_all_generic_params(&input_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
    // `test` is not used by the code generated here.
    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // If generic type parameters are present and no explicit output_type_expr,
    // auto-derive one from the structural relationship between input and output types.
    // Use raw (pre-erasure) types so we can see the generic parameters.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input_ty_raw],
                output_ty_raw,
                &generic_params,
                true,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    // Reject modifiers that do not apply to unary functions, and inconsistent
    // combinations of the output-type modifiers.
    if is_infix_op.is_some() {
        return Err(darling::Error::unknown_field(
            "is_infix_op not supported for unary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if negate.is_some() {
        return Err(darling::Error::unknown_field(
            "negate not supported for unary functions",
        ));
    }
    if propagates_nulls.is_some() {
        return Err(darling::Error::unknown_field(
            "propagates_nulls not supported for unary functions",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for unary functions",
        ));
    }

    // Optional trait method overrides; each is emitted only when its modifier is set.
    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
        quote! {
            fn preserves_uniqueness(&self) -> bool {
                #preserves_uniqueness
            }
        }
    });

    let inverse_fn = inverse.as_ref().map(|inverse| {
        quote! {
            fn inverse(&self) -> Option<crate::UnaryFunc> {
                #inverse
            }
        }
    });

    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> bool {
                #is_monotone
            }
        }
    });

    // Display name: the `sqlname` modifier if present, otherwise the Rust function name.
    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // An explicit `output_type` supplies both the column type and a default
    // `introduces_nulls` taken from its `OutputDatumType::nullable()`. Otherwise the
    // column type is computed from the `Output` associated type.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    // An `output_type_expr` (explicit or auto-derived) replaces the column type expression.
    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit or auto-derived `introduces_nulls` overrides the default above.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_eliminable_cast_fn = is_eliminable_cast.map(|is_eliminable_cast| {
        quote! {
            fn is_eliminable_cast(&self) -> bool {
                #is_eliminable_cast
            }
        }
    });

    let result = quote! {
        #[derive(
            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub struct #struct_name;

        impl crate::func::EagerUnaryFunc for #struct_name {
            type Input<'a> = #input_ty;
            type Output<'a> = #output_ty;

            fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
                #fn_name(a)
            }

            fn output_sql_type(
                &self,
                input_type: mz_repr::SqlColumnType
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                // The output is nullable if it is nullable by itself or the input is nullable
                // and this function propagates nulls
                output.nullable(nullable || (propagates_nulls && input_type.nullable))
            }

            #could_error_fn
            #introduces_nulls_fn
            #inverse_fn
            #is_monotone_fn
            #preserves_uniqueness_fn
            #is_eliminable_cast_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func
    };
    Ok(result)
}
1046
/// Produce a `EagerBinaryFunc` implementation.
///
/// Expands to:
/// * a unit struct named after the camel-cased function name,
/// * an `impl crate::func::binary::EagerBinaryFunc` whose `call` forwards `(a, b)` to the
///   function,
/// * a `Display` impl that writes `sqlname` if given, or else the function's name,
/// * the original function, emitted unchanged.
///
/// When `arena` is true, `temp_storage` is passed to the function as a third argument.
///
/// Generic type parameters are replaced by `Datum<'a>` in the associated types. When
/// generics are present and no output type modifier is given, an output type expression
/// may be auto-derived, as in `unary_func`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * either of the first two parameters is a `self` receiver;
/// * the function has no return type;
/// * deriving or parsing the generic output type expression fails;
/// * a modifier that does not apply to binary functions is set (`preserves_uniqueness`,
///   `inverse`, `is_associative`, `is_eliminable_cast`);
/// * both `output_type` and `output_type_expr` are set;
/// * `output_type_expr` is set without `introduces_nulls`.
///
/// # Panics
///
/// Panics if the function has fewer than two parameters, since `arg_type` indexes
/// parameters 0 and 1.
fn binary_func(
    func: &syn::ItemFn,
    modifiers: Modifiers,
    arena: bool,
) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input1_ty_raw = arg_type(func, 0)?;
    let input2_ty_raw = arg_type(func, 1)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // Erase generic type params → Datum<'a> for use in the trait impl's associated types.
    let input1_ty = erase_all_generic_params(&input1_ty_raw, &generic_params);
    let input2_ty = erase_all_generic_params(&input2_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);

    // `test` is not used by the code generated here.
    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // Auto-derive output_type_expr from generic parameters, if applicable.
    // Use raw (pre-erasure) types so we can see the generic parameters.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input1_ty_raw, input2_ty_raw],
                output_ty_raw,
                &generic_params,
                false,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    // Reject modifiers that do not apply to binary functions, and inconsistent
    // combinations of the output-type modifiers.
    if preserves_uniqueness.is_some() {
        return Err(darling::Error::unknown_field(
            "preserves_uniqueness not supported for binary functions",
        ));
    }
    if inverse.is_some() {
        return Err(darling::Error::unknown_field(
            "inverse not supported for binary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for binary functions",
        ));
    }
    if is_eliminable_cast.is_some() {
        return Err(darling::Error::unknown_field(
            "is_eliminable_cast not supported for binary functions",
        ));
    }

    // Optional trait method overrides; each is emitted only when its modifier is set.
    let negate_fn = negate.map(|negate| {
        quote! {
            fn negate(&self) -> Option<crate::BinaryFunc> {
                #negate
            }
        }
    });

    // For binary functions `is_monotone` yields one flag per argument.
    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> (bool, bool) {
                #is_monotone
            }
        }
    });

    // Display name: the `sqlname` modifier if present, otherwise the Rust function name.
    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // An explicit `output_type` supplies both the column type and a default
    // `introduces_nulls` taken from its `OutputDatumType::nullable()`. Otherwise the
    // column type is computed from the `Output` associated type.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    // An `output_type_expr` (explicit or auto-derived) replaces the column type expression.
    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit or auto-derived `introduces_nulls` overrides the default above.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    // Extra trailing call argument when the function takes the row arena.
    let arena = if arena {
        quote! { , temp_storage }
    } else {
        quote! {}
    };

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
        quote! {
            fn is_infix_op(&self) -> bool {
                #is_infix_op
            }
        }
    });

    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
        quote! {
            fn propagates_nulls(&self) -> bool {
                #propagates_nulls
            }
        }
    });

    // Per-position checks: for each non-nullable parameter, check if
    // the corresponding input column is nullable.
    let binary_non_nullable_checks =
        non_nullable_position_checks(&[input1_ty.clone(), input2_ty.clone()]);

    let result = quote! {
        #[derive(
            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub struct #struct_name;

        impl crate::func::binary::EagerBinaryFunc for #struct_name {
            type Input<'a> = (#input1_ty, #input2_ty);
            type Output<'a> = #output_ty;

            fn call<'a>(
                &self,
                (a, b): Self::Input<'a>,
                temp_storage: &'a mz_repr::RowArena
            ) -> Self::Output<'a> {
                #fn_name(a, b #arena)
            }

            fn output_sql_type(
                &self,
                input_types: &[mz_repr::SqlColumnType],
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls =
                    crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                // The output is nullable if:
                // 1. The function introduces nulls (output.nullable), or
                // 2. A non-nullable parameter's input is nullable (will reject
                //    NULL at runtime via try_from_iter), or
                // 3. propagates_nulls is true and any input is nullable
                //    (optimizer short-circuits all-NULL inputs)
                let non_nullable_input_is_nullable =
                    false #(#binary_non_nullable_checks)*;
                let inputs_nullable = input_types.iter().any(|it| it.nullable);
                let is_null = nullable
                    || non_nullable_input_is_nullable
                    || (propagates_nulls && inputs_nullable);
                output.nullable(is_null)
            }

            #could_error_fn
            #introduces_nulls_fn
            #is_infix_op_fn
            #is_monotone_fn
            #negate_fn
            #propagates_nulls_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func

    };
    Ok(result)
}
1273
1274/// Produce an `EagerVariadicFunc` implementation.
1275///
1276/// Two modes based on whether the function has a `&self` receiver:
1277/// * `&self` present: struct defined externally, generates method impl + trait impl + Display
1278/// * No `&self`: generates unit struct + trait impl + Display + preserves original function
1279fn variadic_func(
1280    func: &syn::ItemFn,
1281    modifiers: Modifiers,
1282    struct_ty: Option<syn::Path>,
1283    arena: bool,
1284    has_self: bool,
1285) -> darling::Result<TokenStream> {
1286    let fn_name = &func.sig.ident;
1287    let output_ty_raw = output_type(func)?;
1288    let generic_params = find_generic_type_params(func);
1289    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
1290    let struct_name = struct_ty
1291        .as_ref()
1292        .and_then(|ty| ty.segments.last())
1293        .map_or_else(|| camel_case(fn_name), |seg| seg.ident.clone());
1294
1295    let Modifiers {
1296        is_monotone,
1297        sqlname,
1298        preserves_uniqueness,
1299        inverse,
1300        is_infix_op,
1301        output_type,
1302        mut output_type_expr,
1303        negate,
1304        could_error,
1305        propagates_nulls,
1306        mut introduces_nulls,
1307        is_associative,
1308        is_eliminable_cast,
1309        test: _,
1310    } = modifiers;
1311
1312    // Reject modifiers that don't apply to variadic functions.
1313    if preserves_uniqueness.is_some() {
1314        return Err(darling::Error::unknown_field(
1315            "preserves_uniqueness not supported for variadic functions",
1316        ));
1317    }
1318    if inverse.is_some() {
1319        return Err(darling::Error::unknown_field(
1320            "inverse not supported for variadic functions",
1321        ));
1322    }
1323    if negate.is_some() {
1324        return Err(darling::Error::unknown_field(
1325            "negate not supported for variadic functions",
1326        ));
1327    }
1328    if is_eliminable_cast.is_some() {
1329        return Err(darling::Error::unknown_field(
1330            "is_eliminable_cast not supported for variadic functions",
1331        ));
1332    }
1333    if output_type.is_some() && output_type_expr.is_some() {
1334        return Err(darling::Error::unknown_field(
1335            "output_type and output_type_expr cannot be used together",
1336        ));
1337    }
1338    if output_type_expr.is_some() && introduces_nulls.is_none() {
1339        return Err(darling::Error::unknown_field(
1340            "output_type_expr requires introduces_nulls",
1341        ));
1342    }
1343
1344    // Collect input parameters (skip &self, skip &RowArena).
1345    let start = if has_self { 1 } else { 0 };
1346    let end = if arena {
1347        func.sig.inputs.len() - 1
1348    } else {
1349        func.sig.inputs.len()
1350    };
1351    let input_params: Vec<&syn::FnArg> = func
1352        .sig
1353        .inputs
1354        .iter()
1355        .skip(start)
1356        .take(end - start)
1357        .collect();
1358
1359    if input_params.is_empty() {
1360        return Err(darling::Error::custom(
1361            "variadic function must have at least one input parameter",
1362        ));
1363    }
1364
1365    // Extract parameter names and types.
1366    let mut param_names = Vec::new();
1367    let mut param_types = Vec::new();
1368    for param in &input_params {
1369        match param {
1370            syn::FnArg::Typed(pat) => {
1371                if let syn::Pat::Ident(ident) = &*pat.pat {
1372                    param_names.push(ident.ident.clone());
1373                } else {
1374                    return Err(
1375                        darling::Error::custom("unsupported parameter pattern").with_span(&pat.pat)
1376                    );
1377                }
1378                param_types.push(patch_lifetimes(&pat.ty));
1379            }
1380            syn::FnArg::Receiver(_) => {
1381                return Err(darling::Error::custom("unexpected self parameter"));
1382            }
1383        }
1384    }
1385
1386    // Auto-derive output_type_expr from generic parameters, if applicable.
1387    // Use raw (pre-erasure) types so we can see the generic parameters.
1388    if !generic_params.is_empty() {
1389        if output_type_expr.is_none() && output_type.is_none() {
1390            if let Some(derived) = derive_output_type_for_generics(
1391                &param_types,
1392                output_ty_raw,
1393                &generic_params,
1394                false,
1395            )? {
1396                output_type_expr = Some(syn::parse2(derived)?);
1397                if introduces_nulls.is_none() {
1398                    let nullable = is_option_wrapped(output_ty_raw);
1399                    introduces_nulls = Some(syn::parse_quote!(#nullable));
1400                }
1401            }
1402        }
1403    }
1404
1405    // Erase generic type params → Datum<'a> in param types for the trait impl's associated types.
1406    for ty in &mut param_types {
1407        *ty = erase_all_generic_params(ty, &generic_params);
1408    }
1409
1410    // Build input type: single param = bare type, multiple = tuple.
1411    let input_type: syn::Type = if param_types.len() == 1 {
1412        param_types[0].clone()
1413    } else {
1414        syn::parse_quote! { (#(#param_types),*) }
1415    };
1416
1417    // Build destructure pattern for call.
1418    let destructure = if param_names.len() == 1 {
1419        let name = &param_names[0];
1420        quote! { #name }
1421    } else {
1422        quote! { (#(#param_names),*) }
1423    };
1424
1425    let arena_arg = if arena {
1426        quote! { , temp_storage }
1427    } else {
1428        quote! {}
1429    };
1430
1431    let call_expr = if has_self {
1432        quote! { self.#fn_name(#(#param_names),* #arena_arg) }
1433    } else {
1434        quote! { #fn_name(#(#param_names),* #arena_arg) }
1435    };
1436
1437    // Build modifier functions.
1438    let name = sqlname
1439        .as_ref()
1440        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
1441
1442    let (mut output_type_code, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
1443        let introduces_nulls_fn = quote! {
1444            fn introduces_nulls(&self) -> bool {
1445                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
1446            }
1447        };
1448        let output_type_code = quote! { <#output_type>::as_column_type() };
1449        (output_type_code, Some(introduces_nulls_fn))
1450    } else {
1451        (quote! { Self::Output::as_column_type() }, None)
1452    };
1453
1454    if let Some(output_type_expr) = output_type_expr {
1455        output_type_code = quote! { #output_type_expr };
1456    }
1457
1458    if let Some(introduces_nulls) = introduces_nulls {
1459        introduces_nulls_fn = Some(quote! {
1460            fn introduces_nulls(&self) -> bool {
1461                #introduces_nulls
1462            }
1463        });
1464    }
1465
1466    let could_error_fn = could_error.map(|could_error| {
1467        quote! {
1468            fn could_error(&self) -> bool {
1469                #could_error
1470            }
1471        }
1472    });
1473
1474    let is_monotone_fn = is_monotone.map(|is_monotone| {
1475        quote! {
1476            fn is_monotone(&self) -> bool {
1477                #is_monotone
1478            }
1479        }
1480    });
1481
1482    let is_associative_fn = is_associative.map(|is_associative| {
1483        quote! {
1484            fn is_associative(&self) -> bool {
1485                #is_associative
1486            }
1487        }
1488    });
1489
1490    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
1491        quote! {
1492            fn is_infix_op(&self) -> bool {
1493                #is_infix_op
1494            }
1495        }
1496    });
1497
1498    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
1499        quote! {
1500            fn propagates_nulls(&self) -> bool {
1501                #propagates_nulls
1502            }
1503        }
1504    });
1505
1506    // Per-position checks: for each non-nullable parameter, check if
1507    // the corresponding input column is nullable.
1508    let non_nullable_checks = non_nullable_position_checks(&param_types);
1509
1510    let trait_impl = quote! {
1511        impl crate::func::variadic::EagerVariadicFunc for #struct_name {
1512            type Input<'a> = #input_type;
1513            type Output<'a> = #output_ty;
1514
1515            fn call<'a>(
1516                &self,
1517                #destructure: Self::Input<'a>,
1518                temp_storage: &'a mz_repr::RowArena,
1519            ) -> Self::Output<'a> {
1520                #call_expr
1521            }
1522
1523            fn output_type(
1524                &self,
1525                input_types: &[mz_repr::SqlColumnType],
1526            ) -> mz_repr::SqlColumnType {
1527                use mz_repr::AsColumnType;
1528                let output = #output_type_code;
1529                let propagates_nulls =
1530                    crate::func::variadic::EagerVariadicFunc::propagates_nulls(self);
1531                let nullable = output.nullable;
1532                // The output is nullable if:
1533                // 1. The function introduces nulls (output.nullable), or
1534                // 2. A non-nullable parameter's input is nullable (will reject
1535                //    NULL at runtime via try_from_iter), or
1536                // 3. propagates_nulls is true and any input is nullable
1537                //    (optimizer short-circuits all-NULL inputs)
1538                let non_nullable_input_is_nullable =
1539                    false #(#non_nullable_checks)*;
1540                let inputs_nullable = input_types.iter().any(|it| it.nullable);
1541                output.nullable(
1542                    nullable
1543                    || non_nullable_input_is_nullable
1544                    || (propagates_nulls && inputs_nullable)
1545                )
1546            }
1547
1548            #could_error_fn
1549            #introduces_nulls_fn
1550            #is_infix_op_fn
1551            #is_monotone_fn
1552            #is_associative_fn
1553            #propagates_nulls_fn
1554        }
1555    };
1556
1557    let display_impl = quote! {
1558        impl std::fmt::Display for #struct_name {
1559            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1560                f.write_str(#name)
1561            }
1562        }
1563    };
1564
1565    let result = if has_self {
1566        // External struct: generate method impl + trait impl + Display.
1567        quote! {
1568            impl #struct_name {
1569                #func
1570            }
1571            #trait_impl
1572            #display_impl
1573        }
1574    } else {
1575        // Unit struct: generate struct + trait impl + Display + original function.
1576        quote! {
1577            #[derive(
1578                proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
1579                Debug, Eq, PartialEq, serde::Serialize,
1580                serde::Deserialize, Hash, mz_lowertest::MzReflect,
1581            )]
1582            pub struct #struct_name;
1583
1584            #trait_impl
1585            #display_impl
1586
1587            #func
1588        }
1589    };
1590
1591    Ok(result)
1592}