1use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Optional modifiers accepted by `#[sqlfunc(...)]`, parsed with darling.
///
/// Each field corresponds to one `key = value` argument of the attribute. Which modifiers are
/// accepted depends on the arity of the annotated function; the `*_func` generators reject
/// the ones they do not support with an error.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// Body of the generated `is_monotone` method.
    is_monotone: Option<Expr>,
    /// Name written by the generated `Display` impl; defaults to the Rust function name.
    sqlname: Option<SqlName>,
    /// Body of the generated `preserves_uniqueness` method (rejected for binary and
    /// variadic functions).
    preserves_uniqueness: Option<Expr>,
    /// Body of the generated `inverse` method (rejected for binary and variadic functions).
    inverse: Option<Expr>,
    /// Body of the generated `negate` method (rejected for unary and variadic functions).
    negate: Option<Expr>,
    /// Body of the generated `is_infix_op` method (rejected for unary functions).
    is_infix_op: Option<Expr>,
    /// Rust type whose `as_column_type()` supplies the output SQL type; also determines
    /// `introduces_nulls` unless that is given explicitly.
    output_type: Option<syn::Path>,
    /// Expression evaluating to the output SQL type; cannot be combined with `output_type`
    /// and requires `introduces_nulls`.
    output_type_expr: Option<Expr>,
    /// Body of the generated `could_error` method.
    could_error: Option<Expr>,
    /// Body of the generated `propagates_nulls` method (rejected for unary functions).
    propagates_nulls: Option<Expr>,
    /// Body of the generated `introduces_nulls` method; overrides the one implied by
    /// `output_type`.
    introduces_nulls: Option<Expr>,
    /// Associativity flag (rejected for unary and binary functions).
    is_associative: Option<Expr>,
    /// Body of the generated `is_eliminable_cast` method (rejected for binary and variadic
    /// functions).
    is_eliminable_cast: Option<Expr>,
    /// Whether to emit a snapshot test of the macro expansion.
    test: Option<bool>,
}
51
/// The value of the `sqlname` modifier, re-emitted verbatim into the generated `Display` impl.
#[derive(Debug)]
enum SqlName {
    /// A literal, such as `sqlname = "foo"`.
    Literal(syn::Lit),
    /// A macro invocation that expands to the name.
    Macro(syn::ExprMacro),
}
61
62impl quote::ToTokens for SqlName {
63 fn to_tokens(&self, tokens: &mut TokenStream) {
64 let name = match self {
65 SqlName::Literal(lit) => quote! { #lit },
66 SqlName::Macro(mac) => quote! { #mac },
67 };
68 tokens.extend(name);
69 }
70}
71
72impl darling::FromMeta for SqlName {
73 fn from_value(value: &Lit) -> darling::Result<Self> {
74 Ok(Self::Literal(value.clone()))
75 }
76 fn from_expr(expr: &Expr) -> darling::Result<Self> {
77 match expr {
78 Expr::Lit(lit) => Self::from_value(&lit.lit),
79 Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
80 Expr::Group(mac) => Self::from_expr(&mac.expr),
83 _ => Err(darling::Error::unexpected_expr_type(expr)),
84 }
85 }
86}
87
88pub fn sqlfunc(
94 attr: TokenStream,
95 item: TokenStream,
96 include_test: bool,
97) -> darling::Result<TokenStream> {
98 let mut attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
99
100 let struct_ty = match attr_args.first() {
102 Some(darling::ast::NestedMeta::Meta(syn::Meta::Path(_))) => {
103 let darling::ast::NestedMeta::Meta(syn::Meta::Path(path)) = attr_args.remove(0) else {
104 unreachable!()
105 };
106 Some(path)
107 }
108 _ => None,
109 };
110
111 let modifiers = Modifiers::from_list(&attr_args).unwrap();
112 let generate_tests = modifiers.test.unwrap_or(false);
113 let func = syn::parse2::<syn::ItemFn>(item.clone())?;
114
115 let tokens = match determine_arity(&func) {
116 Arity::Nullary => Err(darling::Error::custom("Nullary functions not supported")),
117 Arity::Unary { arena: false } => unary_func(&func, modifiers),
118 Arity::Unary { arena: true } => Err(darling::Error::custom(
119 "Unary functions do not yet support RowArena.",
120 )),
121 Arity::Binary { arena } => binary_func(&func, modifiers, arena),
122 Arity::Variadic { arena, has_self } => {
123 variadic_func(&func, modifiers, struct_ty, arena, has_self)
124 }
125 }?;
126
127 let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
128
129 Ok(quote! {
130 #tokens
131 #test
132 })
133}
134
/// Emits a `#[cfg(test)]` snapshot test named `test_<name>`.
///
/// The test passes the attribute and item source text to
/// `mz_expr_derive_impl::test_sqlfunc_str` and snapshots the returned output with `insta`,
/// using the function name as the snapshot name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let attr_src = attr.to_string();
    let item_src = item.to_string();
    let snapshot_name = name.to_string();
    let test_fn = Ident::new(&format!("test_{snapshot_name}"), name.span());

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)]
        #[mz_ore::test]
        fn #test_fn() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
152
153#[cfg(not(any(feature = "test", test)))]
154fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
155 quote! {}
156}
157
158fn last_is_arena(func: &syn::ItemFn) -> bool {
160 func.sig.inputs.last().map_or(false, |last| {
161 if let syn::FnArg::Typed(pat) = last {
162 if let syn::Type::Reference(reference) = &*pat.ty {
163 if let syn::Type::Path(path) = &*reference.elem {
164 return path.path.is_ident("RowArena");
165 }
166 }
167 }
168 false
169 })
170}
171
/// Shape of an annotated function's parameter list, used to choose a code generator.
///
/// Counts exclude a leading `self` receiver and a trailing `&RowArena` parameter.
enum Arity {
    /// No value parameters.
    Nullary,
    /// One value parameter; `arena` is true when a trailing `RowArena` reference follows.
    Unary { arena: bool },
    /// Two value parameters; `arena` is true when a trailing `RowArena` reference follows.
    Binary { arena: bool },
    /// Three or more value parameters, or any `Variadic`/`OptionalArg` parameter.
    /// `has_self` is true when the first parameter is a `self` receiver.
    Variadic { arena: bool, has_self: bool },
}
179
180fn is_variadic_arg(arg: &syn::FnArg) -> bool {
184 if let syn::FnArg::Typed(pat) = arg {
185 if let syn::Type::Path(path) = &*pat.ty {
186 if let Some(segment) = path.path.segments.last() {
187 let ident = segment.ident.to_string();
188 return ident == "Variadic" || ident == "OptionalArg";
189 }
190 }
191 }
192 false
193}
194
195fn determine_arity(func: &syn::ItemFn) -> Arity {
201 let arena = last_is_arena(func);
202 let has_self = matches!(func.sig.inputs.first(), Some(syn::FnArg::Receiver(_)));
203
204 let mut effective_count = func.sig.inputs.len();
205 if arena {
206 effective_count -= 1;
207 }
208 if has_self {
209 effective_count -= 1;
210 }
211
212 let start = if has_self { 1 } else { 0 };
214 let end = if arena {
215 func.sig.inputs.len() - 1
216 } else {
217 func.sig.inputs.len()
218 };
219 let has_variadic_param = func
220 .sig
221 .inputs
222 .iter()
223 .skip(start)
224 .take(end - start)
225 .any(is_variadic_arg);
226
227 if has_variadic_param || effective_count >= 3 {
228 Arity::Variadic { arena, has_self }
229 } else {
230 match effective_count {
231 0 => Arity::Nullary,
232 1 => Arity::Unary { arena },
233 2 => Arity::Binary { arena },
234 _ => unreachable!(),
235 }
236 }
237}
238
239fn is_nullable_type(ty: &syn::Type) -> bool {
246 if let syn::Type::Path(type_path) = ty {
247 if let Some(last_segment) = type_path.path.segments.last() {
248 let ident = &last_segment.ident;
249 if ident == "Option" || ident == "Datum" {
250 return true;
251 }
252 if ident == "OptionalArg" {
253 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
255 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
256 return is_nullable_type(inner_ty);
257 }
258 }
259 return false;
260 }
261 }
262 }
263 false
264}
265
266fn is_variadic_type(ty: &syn::Type) -> bool {
268 if let syn::Type::Path(type_path) = ty {
269 if let Some(last_segment) = type_path.path.segments.last() {
270 return last_segment.ident == "Variadic";
271 }
272 }
273 false
274}
275
276fn variadic_element_is_nullable(ty: &syn::Type) -> bool {
278 if let syn::Type::Path(type_path) = ty {
279 if let Some(last_segment) = type_path.path.segments.last() {
280 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
281 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
282 return is_nullable_type(inner_ty);
283 }
284 }
285 }
286 }
287 false
288}
289
290fn non_nullable_position_checks(param_types: &[syn::Type]) -> Vec<TokenStream> {
296 let mut checks = Vec::new();
297 for (i, ty) in param_types.iter().enumerate() {
298 if is_variadic_type(ty) {
299 if !variadic_element_is_nullable(ty) {
300 checks.push(quote! { || input_types.iter().skip(#i).any(|t| t.nullable) });
301 }
302 } else if !is_nullable_type(ty) {
303 checks.push(quote! { || input_types.get(#i).map_or(false, |t| t.nullable) });
304 }
305 }
306 checks
307}
308
309fn camel_case(ident: &Ident) -> Ident {
310 let mut result = String::new();
311 let mut capitalize_next = true;
312 for c in ident.to_string().chars() {
313 if c == '_' {
314 capitalize_next = true;
315 } else if capitalize_next {
316 result.push(c.to_ascii_uppercase());
317 capitalize_next = false;
318 } else {
319 result.push(c);
320 }
321 }
322 Ident::new(&result, ident.span())
323}
324
325fn find_generic_type_params(func: &syn::ItemFn) -> Vec<Ident> {
328 func.sig
329 .generics
330 .params
331 .iter()
332 .filter_map(|p| {
333 if let syn::GenericParam::Type(tp) = p {
334 Some(tp.ident.clone())
335 } else {
336 None
337 }
338 })
339 .collect()
340}
341
/// How a generic type parameter occurs within a type, as computed by
/// `classify_generic_usage`.
#[derive(Debug, Clone)]
enum GenericUsage {
    /// The parameter does not occur.
    Absent,
    /// The parameter occurs directly, e.g. `T`, possibly behind `&`, `Option`, `Result`, or
    /// `ExcludeNull`.
    Bare,
    /// The parameter occurs as a type argument of some other path type; holds that path with
    /// the parameter replaced by `Datum<'a>`.
    InContainer(syn::TypePath),
}
353
354impl PartialEq for GenericUsage {
355 fn eq(&self, other: &Self) -> bool {
356 match (self, other) {
357 (GenericUsage::Absent, GenericUsage::Absent) => true,
358 (GenericUsage::Bare, GenericUsage::Bare) => true,
359 (GenericUsage::InContainer(a), GenericUsage::InContainer(b)) => {
360 container_idents_match(a, b)
361 }
362 _ => false,
363 }
364 }
365}
366
367impl Eq for GenericUsage {}
368
369fn container_idents_match(a: &syn::TypePath, b: &syn::TypePath) -> bool {
375 let a_idents: Vec<_> = a.path.segments.iter().map(|s| &s.ident).collect();
376 let b_idents: Vec<_> = b.path.segments.iter().map(|s| &s.ident).collect();
377 a_idents == b_idents
378}
379
/// Classifies how `generic_name` occurs in `ty`.
///
/// - A path that is exactly `generic_name` is `Bare`.
/// - For `Option<_>`, `Result<_, _>` and `ExcludeNull<_>` (matched on the last segment), the
///   result for the first type argument is returned as-is.
/// - Any other path whose last segment has a type argument mentioning the generic (at any
///   depth) is `InContainer`, carrying the path with the generic erased to `Datum<'a>`.
/// - References are looked through; tuple elements are combined as described inline.
/// - Everything else is `Absent`.
fn classify_generic_usage(ty: &syn::Type, generic_name: &Ident) -> GenericUsage {
    match ty {
        syn::Type::Path(type_path) => {
            if type_path.path.is_ident(generic_name) {
                return GenericUsage::Bare;
            }
            if let Some(last) = type_path.path.segments.last() {
                let ident_str = last.ident.to_string();
                // Transparent wrappers: classify by their first type argument.
                // NOTE(review): this returns even when that argument is `Absent`, so e.g.
                // `Result<X, T>` is `Absent` — confirm T never appears only in a later argument.
                if ident_str == "Option" || ident_str == "Result" || ident_str == "ExcludeNull" {
                    if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
                        if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
                            return classify_generic_usage(inner, generic_name);
                        }
                    }
                }
                if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
                    let has_generic_arg = args.args.iter().any(|arg| {
                        if let syn::GenericArgument::Type(inner) = arg {
                            type_contains_ident(inner, generic_name)
                        } else {
                            false
                        }
                    });
                    if has_generic_arg {
                        // Erasing a path type always yields a path type, so this always returns.
                        let erased = erase_generic_param(ty, generic_name);
                        if let syn::Type::Path(erased_path) = erased {
                            return GenericUsage::InContainer(erased_path);
                        }
                    }
                    // NOTE(review): reached only when no type argument mentions the generic,
                    // in which case every recursive call appears to return `Absent`; this loop
                    // looks effectively dead — confirm before relying on it.
                    for arg in &args.args {
                        if let syn::GenericArgument::Type(inner) = arg {
                            let inner_usage = classify_generic_usage(inner, generic_name);
                            if inner_usage != GenericUsage::Absent {
                                return inner_usage;
                            }
                        }
                    }
                }
            }
            GenericUsage::Absent
        }
        syn::Type::Reference(r) => classify_generic_usage(&r.elem, generic_name),
        syn::Type::Tuple(t) => {
            // Combine element usages: the first non-`Absent` usage is kept, a `Bare` is
            // upgraded by any later non-`Absent` usage, and a kept container conflicting
            // with a later usage collapses the whole tuple to `Bare`.
            // NOTE(review): this is order-dependent — `(T, Array<T>)` yields `InContainer`
            // while `(Array<T>, T)` yields `Bare`; confirm that is intended.
            let mut best = GenericUsage::Absent;
            for elem in &t.elems {
                let usage = classify_generic_usage(elem, generic_name);
                match (&best, &usage) {
                    (GenericUsage::Absent, _) => best = usage,
                    (GenericUsage::Bare, u) if *u != GenericUsage::Absent => best = usage.clone(),
                    _ => {
                        if usage != GenericUsage::Absent && usage != best {
                            return GenericUsage::Bare;
                        }
                    }
                }
            }
            best
        }
        _ => GenericUsage::Absent,
    }
}
457
458fn type_contains_ident(ty: &syn::Type, ident: &Ident) -> bool {
460 match ty {
461 syn::Type::Path(type_path) => {
462 if type_path.path.is_ident(ident) {
463 return true;
464 }
465 if let Some(last) = type_path.path.segments.last() {
466 if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
467 return args.args.iter().any(|arg| {
468 if let syn::GenericArgument::Type(inner) = arg {
469 type_contains_ident(inner, ident)
470 } else {
471 false
472 }
473 });
474 }
475 }
476 false
477 }
478 syn::Type::Reference(r) => type_contains_ident(&r.elem, ident),
479 syn::Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)),
480 _ => false,
481 }
482}
483
484fn is_option_wrapped(ty: &syn::Type) -> bool {
486 if let syn::Type::Path(type_path) = ty {
487 if let Some(last) = type_path.path.segments.last() {
488 return last.ident == "Option";
489 }
490 }
491 false
492}
493
494fn derive_output_type_for_generics(
506 input_types: &[syn::Type],
507 output_ty: &syn::Type,
508 generic_names: &[Ident],
509 is_unary: bool,
510) -> darling::Result<Option<TokenStream>> {
511 let generic_name = match generic_names
513 .iter()
514 .find(|gn| classify_generic_usage(output_ty, gn) != GenericUsage::Absent)
515 {
516 Some(gn) => gn,
517 None => return Ok(None),
518 };
519 derive_output_type_for_generic(input_types, output_ty, generic_name, is_unary)
520}
521
522fn derive_output_type_for_generic(
528 input_types: &[syn::Type],
529 output_ty: &syn::Type,
530 generic_name: &Ident,
531 is_unary: bool,
532) -> darling::Result<Option<TokenStream>> {
533 let output_usage = classify_generic_usage(output_ty, generic_name);
534 if output_usage == GenericUsage::Absent {
535 return Ok(None);
536 }
537
538 let nullable = is_option_wrapped(output_ty);
539
540 let mut container_input: Option<(usize, GenericUsage)> = None;
543 for (i, ty) in input_types.iter().enumerate() {
544 let usage = classify_generic_usage(ty, generic_name);
545 match &usage {
546 GenericUsage::InContainer(_) => {
547 container_input = Some((i, usage));
548 break;
549 }
550 GenericUsage::Bare => {
551 if container_input.is_none() {
553 container_input = Some((i, usage));
554 }
555 }
556 GenericUsage::Absent => {}
557 }
558 }
559
560 let (pos, source_usage) = container_input.ok_or_else(|| {
561 darling::Error::custom(
562 "generic parameter T appears in the output type but not in any input type",
563 )
564 })?;
565
566 let input_access = if is_unary {
568 quote! { input_type }
569 } else {
570 let pos_lit = syn::Index::from(pos);
571 quote! { input_types[#pos_lit] }
572 };
573
574 let consistency_checks = if !is_unary {
578 let mut checks = Vec::new();
579 for (i, ty) in input_types.iter().enumerate() {
580 if i == pos {
581 continue;
582 }
583 let usage = classify_generic_usage(ty, generic_name);
584 if usage == GenericUsage::Absent {
585 continue;
586 }
587 let primary_elem = element_type_expr(&input_access, &source_usage);
588 let i_lit = syn::Index::from(i);
589 let other_access = quote! { input_types[#i_lit] };
590 let other_elem = element_type_expr(&other_access, &usage);
591 let generic_str = generic_name.to_string();
592 checks.push(quote! {
593 mz_ore::soft_assert_or_log!(
594 #primary_elem.base_eq(#other_elem),
595 "auto-derived sqlfunc output type inference found inconsistent \
596 SQL types for generic {} across inputs: {:?} vs {:?}; \
597 this indicates a bug in polymorphic coercion, builtin \
598 declaration, or sqlfunc inference",
599 #generic_str,
600 #primary_elem,
601 #other_elem,
602 );
603 });
604 }
605 quote! { #(#checks)* }
606 } else {
607 quote! {}
608 };
609
610 let expr = match (&output_usage, &source_usage) {
613 (GenericUsage::Bare, GenericUsage::InContainer(in_container)) => {
615 let in_c = elide_lifetimes(in_container);
616 quote! {
617 {
618 #consistency_checks
619 <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
620 &#input_access.scalar_type
621 ).clone().nullable(#nullable)
622 }
623 }
624 }
625 (GenericUsage::Bare, GenericUsage::Bare) => {
627 quote! {
628 {
629 #consistency_checks
630 #input_access.scalar_type.clone().nullable(#nullable)
631 }
632 }
633 }
634 (GenericUsage::InContainer(out_container), GenericUsage::InContainer(in_container)) => {
637 let out_c = elide_lifetimes(out_container);
638 let in_c = elide_lifetimes(in_container);
639 quote! {
640 {
641 #consistency_checks
642 <#out_c as mz_repr::SqlContainerType>::wrap_element_type(
643 <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
644 &#input_access.scalar_type
645 ).clone()
646 ).nullable(#nullable)
647 }
648 }
649 }
650 _ => {
652 return Err(darling::Error::custom(format!(
653 "cannot auto-derive output_type_expr: output uses T as {:?} but \
654 the first T-containing input uses T as {:?}",
655 output_usage, source_usage
656 )));
657 }
658 };
659
660 Ok(Some(expr))
661}
662
663fn element_type_expr(input_access: &TokenStream, usage: &GenericUsage) -> TokenStream {
666 match usage {
667 GenericUsage::Bare => {
668 quote! { &#input_access.scalar_type }
669 }
670 GenericUsage::InContainer(container) => {
671 let c = elide_lifetimes(container);
672 quote! {
673 <#c as mz_repr::SqlContainerType>::unwrap_element_type(
674 &#input_access.scalar_type
675 )
676 }
677 }
678 GenericUsage::Absent => unreachable!("element_type_expr called with Absent usage"),
679 }
680}
681
682fn elide_lifetimes(tp: &syn::TypePath) -> syn::TypePath {
689 let mut tp = tp.clone();
690 for segment in &mut tp.path.segments {
691 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
692 for arg in &mut args.args {
693 match arg {
694 syn::GenericArgument::Lifetime(lt) => {
695 *lt = Lifetime::new("'_", lt.span());
696 }
697 syn::GenericArgument::Type(ty) => {
698 elide_lifetimes_in_type(ty);
699 }
700 _ => {}
701 }
702 }
703 }
704 }
705 tp
706}
707
708fn elide_lifetimes_in_type(ty: &mut syn::Type) {
710 match ty {
711 syn::Type::Path(tp) => {
712 *tp = elide_lifetimes(tp);
713 }
714 syn::Type::Reference(r) => {
715 if let Some(lt) = &mut r.lifetime {
716 *lt = Lifetime::new("'_", lt.span());
717 }
718 elide_lifetimes_in_type(&mut r.elem);
719 }
720 syn::Type::Tuple(t) => {
721 for elem in &mut t.elems {
722 elide_lifetimes_in_type(elem);
723 }
724 }
725 _ => {}
726 }
727}
728
729fn erase_generic_param(ty: &syn::Type, generic_name: &Ident) -> syn::Type {
734 match ty {
735 syn::Type::Path(type_path) => {
736 if type_path.path.is_ident(generic_name) {
737 return syn::parse_quote!(Datum<'a>);
738 }
739 let mut type_path = type_path.clone();
740 for segment in &mut type_path.path.segments {
741 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
742 for arg in &mut args.args {
743 if let syn::GenericArgument::Type(inner) = arg {
744 *inner = erase_generic_param(inner, generic_name);
745 }
746 }
747 }
748 }
749 syn::Type::Path(type_path)
750 }
751 syn::Type::Reference(r) => {
752 let elem = Box::new(erase_generic_param(&r.elem, generic_name));
753 syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
754 }
755 syn::Type::Tuple(t) => {
756 let elems = t
757 .elems
758 .iter()
759 .map(|e| erase_generic_param(e, generic_name))
760 .collect();
761 syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
762 }
763 _ => ty.clone(),
764 }
765}
766
767fn erase_all_generic_params(ty: &syn::Type, generic_names: &[Ident]) -> syn::Type {
769 let mut ty = ty.clone();
770 for gn in generic_names {
771 ty = erase_generic_param(&ty, gn);
772 }
773 ty
774}
775
776fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
783 match &arg.sig.inputs[nth] {
784 syn::FnArg::Typed(pat) => {
785 if let syn::Type::Reference(r) = &*pat.ty {
787 if r.lifetime.is_none() {
788 let ty = syn::Type::Reference(syn::TypeReference {
789 lifetime: Some(Lifetime::new("'a", r.span())),
790 ..r.clone()
791 });
792 return Ok(ty);
793 }
794 }
795 Ok((*pat.ty).clone())
796 }
797 syn::FnArg::Receiver(_) => Err(syn::Error::new(
798 arg.sig.inputs[nth].span(),
799 "Unsupported argument type",
800 )),
801 }
802}
803
804fn patch_lifetimes(ty: &syn::Type) -> syn::Type {
807 match ty {
808 syn::Type::Reference(r) => {
809 let elem = Box::new(patch_lifetimes(&r.elem));
810 if r.lifetime.is_none() {
811 syn::Type::Reference(syn::TypeReference {
812 lifetime: Some(Lifetime::new("'a", r.span())),
813 elem,
814 ..r.clone()
815 })
816 } else {
817 syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
818 }
819 }
820 syn::Type::Tuple(t) => {
821 let elems = t.elems.iter().map(patch_lifetimes).collect();
822 syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
823 }
824 syn::Type::Path(p) => {
825 let mut p = p.clone();
826 for segment in &mut p.path.segments {
827 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
828 for arg in &mut args.args {
829 if let syn::GenericArgument::Type(ty) = arg {
830 *ty = patch_lifetimes(ty);
831 }
832 }
833 }
834 }
835 syn::Type::Path(p)
836 }
837 _ => ty.clone(),
838 }
839}
840
841fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
844 match &arg.sig.output {
845 syn::ReturnType::Type(_, ty) => Ok(&*ty),
846 syn::ReturnType::Default => Err(syn::Error::new(
847 arg.sig.output.span(),
848 "Function needs to return a value",
849 )),
850 }
851}
852
/// Generates the expansion for a one-parameter function.
///
/// Emits a unit struct named after `func` in CamelCase, implements
/// `crate::func::EagerUnaryFunc` and `std::fmt::Display` for it, and re-emits `func`
/// unchanged. Generic type parameters are erased to `Datum<'a>` in the `Input`/`Output`
/// associated types.
///
/// # Errors
///
/// Fails on modifiers unary functions do not support (`is_infix_op`, `negate`,
/// `propagates_nulls`, `is_associative`), on `output_type` combined with `output_type_expr`,
/// on `output_type_expr` without `introduces_nulls`, and when output-type derivation for a
/// generic function fails.
// NOTE(review): the validation errors use `darling::Error::unknown_field` with a sentence as
// the "field name"; `darling::Error::custom` may render more clearly.
fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input_ty_raw = arg_type(func, 0)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // The generated associated types cannot mention the function's generics, so erase them.
    let input_ty = erase_all_generic_params(&input_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // For a generic function with no explicit output type, derive the output SQL type from
    // the input type; `introduces_nulls` then defaults to `is_option_wrapped` of the return
    // type.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input_ty_raw],
                output_ty_raw,
                &generic_params,
                true,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    if is_infix_op.is_some() {
        return Err(darling::Error::unknown_field(
            "is_infix_op not supported for unary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if negate.is_some() {
        return Err(darling::Error::unknown_field(
            "negate not supported for unary functions",
        ));
    }
    if propagates_nulls.is_some() {
        return Err(darling::Error::unknown_field(
            "propagates_nulls not supported for unary functions",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for unary functions",
        ));
    }

    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
        quote! {
            fn preserves_uniqueness(&self) -> bool {
                #preserves_uniqueness
            }
        }
    });

    let inverse_fn = inverse.as_ref().map(|inverse| {
        quote! {
            fn inverse(&self) -> Option<crate::UnaryFunc> {
                #inverse
            }
        }
    });

    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> bool {
                #is_monotone
            }
        }
    });

    // `Display` writes the `sqlname` modifier if given, else the Rust function name.
    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // Output column type: from `output_type` if given (which also implies `introduces_nulls`),
    // otherwise from the `Output` associated type; `output_type_expr` overrides either.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit (or derived) `introduces_nulls` replaces the one implied by `output_type`.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_eliminable_cast_fn = is_eliminable_cast.map(|is_eliminable_cast| {
        quote! {
            fn is_eliminable_cast(&self) -> bool {
                #is_eliminable_cast
            }
        }
    });

    let result = quote! {
        #[derive(
            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub struct #struct_name;

        impl crate::func::EagerUnaryFunc for #struct_name {
            type Input<'a> = #input_ty;
            type Output<'a> = #output_ty;

            fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
                #fn_name(a)
            }

            // The output is nullable if its own type is, or if a nullable input propagates.
            fn output_sql_type(
                &self,
                input_type: mz_repr::SqlColumnType
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                output.nullable(nullable || (propagates_nulls && input_type.nullable))
            }

            #could_error_fn
            #introduces_nulls_fn
            #inverse_fn
            #is_monotone_fn
            #preserves_uniqueness_fn
            #is_eliminable_cast_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func
    };
    Ok(result)
}
1046
/// Generates the expansion for a two-parameter function.
///
/// Emits a unit struct named after `func` in CamelCase, implements
/// `crate::func::binary::EagerBinaryFunc` and `std::fmt::Display` for it, and re-emits `func`
/// unchanged. When `arena` is true, the generated `call` forwards its `temp_storage` as a
/// third argument. Generic type parameters are erased to `Datum<'a>` in the `Input`/`Output`
/// associated types.
///
/// # Errors
///
/// Fails on modifiers binary functions do not support (`preserves_uniqueness`, `inverse`,
/// `is_associative`, `is_eliminable_cast`), on `output_type` combined with
/// `output_type_expr`, on `output_type_expr` without `introduces_nulls`, and when output-type
/// derivation for a generic function fails.
// NOTE(review): the validation errors use `darling::Error::unknown_field` with a sentence as
// the "field name"; `darling::Error::custom` may render more clearly.
fn binary_func(
    func: &syn::ItemFn,
    modifiers: Modifiers,
    arena: bool,
) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input1_ty_raw = arg_type(func, 0)?;
    let input2_ty_raw = arg_type(func, 1)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // The generated associated types cannot mention the function's generics, so erase them.
    let input1_ty = erase_all_generic_params(&input1_ty_raw, &generic_params);
    let input2_ty = erase_all_generic_params(&input2_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);

    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // For a generic function with no explicit output type, derive the output SQL type from
    // the inputs; `introduces_nulls` then defaults to `is_option_wrapped` of the return type.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input1_ty_raw, input2_ty_raw],
                output_ty_raw,
                &generic_params,
                false,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    if preserves_uniqueness.is_some() {
        return Err(darling::Error::unknown_field(
            "preserves_uniqueness not supported for binary functions",
        ));
    }
    if inverse.is_some() {
        return Err(darling::Error::unknown_field(
            "inverse not supported for binary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for binary functions",
        ));
    }
    if is_eliminable_cast.is_some() {
        return Err(darling::Error::unknown_field(
            "is_eliminable_cast not supported for binary functions",
        ));
    }

    let negate_fn = negate.map(|negate| {
        quote! {
            fn negate(&self) -> Option<crate::BinaryFunc> {
                #negate
            }
        }
    });

    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> (bool, bool) {
                #is_monotone
            }
        }
    });

    // `Display` writes the `sqlname` modifier if given, else the Rust function name.
    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // Output column type: from `output_type` if given (which also implies `introduces_nulls`),
    // otherwise from the `Output` associated type; `output_type_expr` overrides either.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit (or derived) `introduces_nulls` replaces the one implied by `output_type`.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    // Extra trailing call argument forwarding the arena when the function takes one.
    let arena = if arena {
        quote! { , temp_storage }
    } else {
        quote! {}
    };

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
        quote! {
            fn is_infix_op(&self) -> bool {
                #is_infix_op
            }
        }
    });

    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
        quote! {
            fn propagates_nulls(&self) -> bool {
                #propagates_nulls
            }
        }
    });

    // `|| ...` fragments that detect a nullable input at a position whose parameter type
    // cannot hold NULL.
    let binary_non_nullable_checks =
        non_nullable_position_checks(&[input1_ty.clone(), input2_ty.clone()]);

    let result = quote! {
        #[derive(
            proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        pub struct #struct_name;

        impl crate::func::binary::EagerBinaryFunc for #struct_name {
            type Input<'a> = (#input1_ty, #input2_ty);
            type Output<'a> = #output_ty;

            fn call<'a>(
                &self,
                (a, b): Self::Input<'a>,
                temp_storage: &'a mz_repr::RowArena
            ) -> Self::Output<'a> {
                #fn_name(a, b #arena)
            }

            // The output is nullable if its own type is, if a non-nullable parameter position
            // receives a nullable input, or if a nullable input propagates.
            fn output_sql_type(
                &self,
                input_types: &[mz_repr::SqlColumnType],
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls =
                    crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                let non_nullable_input_is_nullable =
                    false #(#binary_non_nullable_checks)*;
                let inputs_nullable = input_types.iter().any(|it| it.nullable);
                let is_null = nullable
                    || non_nullable_input_is_nullable
                    || (propagates_nulls && inputs_nullable);
                output.nullable(is_null)
            }

            #could_error_fn
            #introduces_nulls_fn
            #is_infix_op_fn
            #is_monotone_fn
            #negate_fn
            #propagates_nulls_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func

    };
    Ok(result)
}
1273
1274fn variadic_func(
1280 func: &syn::ItemFn,
1281 modifiers: Modifiers,
1282 struct_ty: Option<syn::Path>,
1283 arena: bool,
1284 has_self: bool,
1285) -> darling::Result<TokenStream> {
1286 let fn_name = &func.sig.ident;
1287 let output_ty_raw = output_type(func)?;
1288 let generic_params = find_generic_type_params(func);
1289 let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
1290 let struct_name = struct_ty
1291 .as_ref()
1292 .and_then(|ty| ty.segments.last())
1293 .map_or_else(|| camel_case(fn_name), |seg| seg.ident.clone());
1294
1295 let Modifiers {
1296 is_monotone,
1297 sqlname,
1298 preserves_uniqueness,
1299 inverse,
1300 is_infix_op,
1301 output_type,
1302 mut output_type_expr,
1303 negate,
1304 could_error,
1305 propagates_nulls,
1306 mut introduces_nulls,
1307 is_associative,
1308 is_eliminable_cast,
1309 test: _,
1310 } = modifiers;
1311
1312 if preserves_uniqueness.is_some() {
1314 return Err(darling::Error::unknown_field(
1315 "preserves_uniqueness not supported for variadic functions",
1316 ));
1317 }
1318 if inverse.is_some() {
1319 return Err(darling::Error::unknown_field(
1320 "inverse not supported for variadic functions",
1321 ));
1322 }
1323 if negate.is_some() {
1324 return Err(darling::Error::unknown_field(
1325 "negate not supported for variadic functions",
1326 ));
1327 }
1328 if is_eliminable_cast.is_some() {
1329 return Err(darling::Error::unknown_field(
1330 "is_eliminable_cast not supported for variadic functions",
1331 ));
1332 }
1333 if output_type.is_some() && output_type_expr.is_some() {
1334 return Err(darling::Error::unknown_field(
1335 "output_type and output_type_expr cannot be used together",
1336 ));
1337 }
1338 if output_type_expr.is_some() && introduces_nulls.is_none() {
1339 return Err(darling::Error::unknown_field(
1340 "output_type_expr requires introduces_nulls",
1341 ));
1342 }
1343
1344 let start = if has_self { 1 } else { 0 };
1346 let end = if arena {
1347 func.sig.inputs.len() - 1
1348 } else {
1349 func.sig.inputs.len()
1350 };
1351 let input_params: Vec<&syn::FnArg> = func
1352 .sig
1353 .inputs
1354 .iter()
1355 .skip(start)
1356 .take(end - start)
1357 .collect();
1358
1359 if input_params.is_empty() {
1360 return Err(darling::Error::custom(
1361 "variadic function must have at least one input parameter",
1362 ));
1363 }
1364
1365 let mut param_names = Vec::new();
1367 let mut param_types = Vec::new();
1368 for param in &input_params {
1369 match param {
1370 syn::FnArg::Typed(pat) => {
1371 if let syn::Pat::Ident(ident) = &*pat.pat {
1372 param_names.push(ident.ident.clone());
1373 } else {
1374 return Err(
1375 darling::Error::custom("unsupported parameter pattern").with_span(&pat.pat)
1376 );
1377 }
1378 param_types.push(patch_lifetimes(&pat.ty));
1379 }
1380 syn::FnArg::Receiver(_) => {
1381 return Err(darling::Error::custom("unexpected self parameter"));
1382 }
1383 }
1384 }
1385
1386 if !generic_params.is_empty() {
1389 if output_type_expr.is_none() && output_type.is_none() {
1390 if let Some(derived) = derive_output_type_for_generics(
1391 ¶m_types,
1392 output_ty_raw,
1393 &generic_params,
1394 false,
1395 )? {
1396 output_type_expr = Some(syn::parse2(derived)?);
1397 if introduces_nulls.is_none() {
1398 let nullable = is_option_wrapped(output_ty_raw);
1399 introduces_nulls = Some(syn::parse_quote!(#nullable));
1400 }
1401 }
1402 }
1403 }
1404
1405 for ty in &mut param_types {
1407 *ty = erase_all_generic_params(ty, &generic_params);
1408 }
1409
1410 let input_type: syn::Type = if param_types.len() == 1 {
1412 param_types[0].clone()
1413 } else {
1414 syn::parse_quote! { (#(#param_types),*) }
1415 };
1416
1417 let destructure = if param_names.len() == 1 {
1419 let name = ¶m_names[0];
1420 quote! { #name }
1421 } else {
1422 quote! { (#(#param_names),*) }
1423 };
1424
1425 let arena_arg = if arena {
1426 quote! { , temp_storage }
1427 } else {
1428 quote! {}
1429 };
1430
1431 let call_expr = if has_self {
1432 quote! { self.#fn_name(#(#param_names),* #arena_arg) }
1433 } else {
1434 quote! { #fn_name(#(#param_names),* #arena_arg) }
1435 };
1436
1437 let name = sqlname
1439 .as_ref()
1440 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
1441
1442 let (mut output_type_code, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
1443 let introduces_nulls_fn = quote! {
1444 fn introduces_nulls(&self) -> bool {
1445 <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
1446 }
1447 };
1448 let output_type_code = quote! { <#output_type>::as_column_type() };
1449 (output_type_code, Some(introduces_nulls_fn))
1450 } else {
1451 (quote! { Self::Output::as_column_type() }, None)
1452 };
1453
1454 if let Some(output_type_expr) = output_type_expr {
1455 output_type_code = quote! { #output_type_expr };
1456 }
1457
1458 if let Some(introduces_nulls) = introduces_nulls {
1459 introduces_nulls_fn = Some(quote! {
1460 fn introduces_nulls(&self) -> bool {
1461 #introduces_nulls
1462 }
1463 });
1464 }
1465
1466 let could_error_fn = could_error.map(|could_error| {
1467 quote! {
1468 fn could_error(&self) -> bool {
1469 #could_error
1470 }
1471 }
1472 });
1473
1474 let is_monotone_fn = is_monotone.map(|is_monotone| {
1475 quote! {
1476 fn is_monotone(&self) -> bool {
1477 #is_monotone
1478 }
1479 }
1480 });
1481
1482 let is_associative_fn = is_associative.map(|is_associative| {
1483 quote! {
1484 fn is_associative(&self) -> bool {
1485 #is_associative
1486 }
1487 }
1488 });
1489
1490 let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
1491 quote! {
1492 fn is_infix_op(&self) -> bool {
1493 #is_infix_op
1494 }
1495 }
1496 });
1497
1498 let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
1499 quote! {
1500 fn propagates_nulls(&self) -> bool {
1501 #propagates_nulls
1502 }
1503 }
1504 });
1505
1506 let non_nullable_checks = non_nullable_position_checks(¶m_types);
1509
1510 let trait_impl = quote! {
1511 impl crate::func::variadic::EagerVariadicFunc for #struct_name {
1512 type Input<'a> = #input_type;
1513 type Output<'a> = #output_ty;
1514
1515 fn call<'a>(
1516 &self,
1517 #destructure: Self::Input<'a>,
1518 temp_storage: &'a mz_repr::RowArena,
1519 ) -> Self::Output<'a> {
1520 #call_expr
1521 }
1522
1523 fn output_type(
1524 &self,
1525 input_types: &[mz_repr::SqlColumnType],
1526 ) -> mz_repr::SqlColumnType {
1527 use mz_repr::AsColumnType;
1528 let output = #output_type_code;
1529 let propagates_nulls =
1530 crate::func::variadic::EagerVariadicFunc::propagates_nulls(self);
1531 let nullable = output.nullable;
1532 let non_nullable_input_is_nullable =
1539 false #(#non_nullable_checks)*;
1540 let inputs_nullable = input_types.iter().any(|it| it.nullable);
1541 output.nullable(
1542 nullable
1543 || non_nullable_input_is_nullable
1544 || (propagates_nulls && inputs_nullable)
1545 )
1546 }
1547
1548 #could_error_fn
1549 #introduces_nulls_fn
1550 #is_infix_op_fn
1551 #is_monotone_fn
1552 #is_associative_fn
1553 #propagates_nulls_fn
1554 }
1555 };
1556
1557 let display_impl = quote! {
1558 impl std::fmt::Display for #struct_name {
1559 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1560 f.write_str(#name)
1561 }
1562 }
1563 };
1564
1565 let result = if has_self {
1566 quote! {
1568 impl #struct_name {
1569 #func
1570 }
1571 #trait_impl
1572 #display_impl
1573 }
1574 } else {
1575 quote! {
1577 #[derive(
1578 proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
1579 Debug, Eq, PartialEq, serde::Serialize,
1580 serde::Deserialize, Hash, mz_lowertest::MzReflect,
1581 )]
1582 pub struct #struct_name;
1583
1584 #trait_impl
1585 #display_impl
1586
1587 #func
1588 }
1589 };
1590
1591 Ok(result)
1592}