1use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Optional `key = value` modifiers accepted by the `sqlfunc` attribute.
///
/// Parsed by `darling` from the attribute arguments that remain after an optional leading
/// struct path has been removed (see `sqlfunc`). The unary, binary, and variadic generators
/// each reject the modifiers that do not apply to their arity.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// Body of the generated `is_monotone` method (returns `bool` for unary functions and
    /// `(bool, bool)` for binary functions).
    is_monotone: Option<Expr>,
    /// Name written by the generated `Display` impl; defaults to the function's name.
    sqlname: Option<SqlName>,
    /// Body of the generated `preserves_uniqueness` method (unary functions only).
    preserves_uniqueness: Option<Expr>,
    /// Body of the generated `inverse` method returning `Option<crate::UnaryFunc>`
    /// (unary functions only).
    inverse: Option<Expr>,
    /// Body of the generated `negate` method returning `Option<crate::BinaryFunc>`
    /// (binary functions only).
    negate: Option<Expr>,
    /// Body of the generated `is_infix_op` method; rejected for unary functions.
    is_infix_op: Option<Expr>,
    /// Type whose `as_column_type()` supplies the SQL output type and whose
    /// `OutputDatumType::nullable()` supplies the default `introduces_nulls`.
    /// Mutually exclusive with `output_type_expr`.
    output_type: Option<syn::Path>,
    /// Expression evaluated inside `output_sql_type` to build the output column type.
    /// Requires `introduces_nulls`; auto-derived for generic functions when neither this
    /// nor `output_type` is given.
    output_type_expr: Option<Expr>,
    /// Body of the generated `could_error` method.
    could_error: Option<Expr>,
    /// Body of the generated `propagates_nulls` method; rejected for unary functions.
    propagates_nulls: Option<Expr>,
    /// Body of the generated `introduces_nulls` method; overrides the value implied by
    /// `output_type`.
    introduces_nulls: Option<Expr>,
    /// Rejected for unary and binary functions.
    is_associative: Option<Expr>,
    /// Body of the generated `is_eliminable_cast` method (unary functions only).
    is_eliminable_cast: Option<Expr>,
    /// Whether to emit a snapshot test of the macro expansion (defaults to `false`).
    test: Option<bool>,
}
51
/// Value of the `sqlname` modifier.
///
/// The wrapped tokens are passed verbatim to `f.write_str(..)` in the generated `Display`
/// impl, so they must evaluate to a `&str`.
#[derive(Debug)]
enum SqlName {
    /// A literal, e.g. `"abs"`.
    Literal(syn::Lit),
    /// A macro invocation whose expansion is used as the name.
    Macro(syn::ExprMacro),
}
61
62impl quote::ToTokens for SqlName {
63 fn to_tokens(&self, tokens: &mut TokenStream) {
64 let name = match self {
65 SqlName::Literal(lit) => quote! { #lit },
66 SqlName::Macro(mac) => quote! { #mac },
67 };
68 tokens.extend(name);
69 }
70}
71
72impl darling::FromMeta for SqlName {
73 fn from_value(value: &Lit) -> darling::Result<Self> {
74 Ok(Self::Literal(value.clone()))
75 }
76 fn from_expr(expr: &Expr) -> darling::Result<Self> {
77 match expr {
78 Expr::Lit(lit) => Self::from_value(&lit.lit),
79 Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
80 Expr::Group(mac) => Self::from_expr(&mac.expr),
83 _ => Err(darling::Error::unexpected_expr_type(expr)),
84 }
85 }
86}
87
88pub fn sqlfunc(
94 attr: TokenStream,
95 item: TokenStream,
96 include_test: bool,
97) -> darling::Result<TokenStream> {
98 let mut attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
99
100 let struct_ty = match attr_args.first() {
102 Some(darling::ast::NestedMeta::Meta(syn::Meta::Path(_))) => {
103 let darling::ast::NestedMeta::Meta(syn::Meta::Path(path)) = attr_args.remove(0) else {
104 unreachable!()
105 };
106 Some(path)
107 }
108 _ => None,
109 };
110
111 let modifiers = Modifiers::from_list(&attr_args).unwrap();
112 let generate_tests = modifiers.test.unwrap_or(false);
113 let func = syn::parse2::<syn::ItemFn>(item.clone())?;
114
115 let tokens = match determine_arity(&func) {
116 Arity::Nullary => Err(darling::Error::custom("Nullary functions not supported")),
117 Arity::Unary { arena: false } => unary_func(&func, modifiers),
118 Arity::Unary { arena: true } => Err(darling::Error::custom(
119 "Unary functions do not yet support RowArena.",
120 )),
121 Arity::Binary { arena } => binary_func(&func, modifiers, arena),
122 Arity::Variadic { arena, has_self } => {
123 variadic_func(&func, modifiers, struct_ty, arena, has_self)
124 }
125 }?;
126
127 let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
128
129 Ok(quote! {
130 #tokens
131 #test
132 })
133}
134
/// Emits a `#[cfg(test)]` test named `test_<name>` that passes the stringified attribute and
/// item to `mz_expr_derive_impl::test_sqlfunc_str` and snapshots the returned output with
/// `insta::assert_snapshot!`, keyed by the function name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let snapshot_name = name.to_string();
    let test_ident = Ident::new(&format!("test_{snapshot_name}"), name.span());
    let attr_src = attr.to_string();
    let item_src = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)]
        #[mz_ore::test]
        fn #test_ident() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
152
153#[cfg(not(any(feature = "test", test)))]
154fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
155 quote! {}
156}
157
158fn last_is_arena(func: &syn::ItemFn) -> bool {
160 func.sig.inputs.last().map_or(false, |last| {
161 if let syn::FnArg::Typed(pat) = last {
162 if let syn::Type::Reference(reference) = &*pat.ty {
163 if let syn::Type::Path(path) = &*reference.elem {
164 return path.path.is_ident("RowArena");
165 }
166 }
167 }
168 false
169 })
170}
171
/// Shape of a `sqlfunc`, computed from its parameter list by `determine_arity`.
///
/// `arena` records whether the last parameter is a reference to `RowArena`; `has_self`
/// records whether the first parameter is a `self` receiver. Neither counts toward the
/// number of value parameters.
enum Arity {
    /// No value parameters.
    Nullary,
    /// Exactly one value parameter.
    Unary { arena: bool },
    /// Exactly two value parameters.
    Binary { arena: bool },
    /// Three or more value parameters, or any `Variadic`/`OptionalArg` parameter.
    Variadic { arena: bool, has_self: bool },
}
179
180fn is_variadic_arg(arg: &syn::FnArg) -> bool {
184 if let syn::FnArg::Typed(pat) = arg {
185 if let syn::Type::Path(path) = &*pat.ty {
186 if let Some(segment) = path.path.segments.last() {
187 let ident = segment.ident.to_string();
188 return ident == "Variadic" || ident == "OptionalArg";
189 }
190 }
191 }
192 false
193}
194
195fn determine_arity(func: &syn::ItemFn) -> Arity {
201 let arena = last_is_arena(func);
202 let has_self = matches!(func.sig.inputs.first(), Some(syn::FnArg::Receiver(_)));
203
204 let mut effective_count = func.sig.inputs.len();
205 if arena {
206 effective_count -= 1;
207 }
208 if has_self {
209 effective_count -= 1;
210 }
211
212 let start = if has_self { 1 } else { 0 };
214 let end = if arena {
215 func.sig.inputs.len() - 1
216 } else {
217 func.sig.inputs.len()
218 };
219 let has_variadic_param = func
220 .sig
221 .inputs
222 .iter()
223 .skip(start)
224 .take(end - start)
225 .any(is_variadic_arg);
226
227 if has_variadic_param || effective_count >= 3 {
228 Arity::Variadic { arena, has_self }
229 } else {
230 match effective_count {
231 0 => Arity::Nullary,
232 1 => Arity::Unary { arena },
233 2 => Arity::Binary { arena },
234 _ => unreachable!(),
235 }
236 }
237}
238
239fn is_nullable_type(ty: &syn::Type) -> bool {
246 if let syn::Type::Path(type_path) = ty {
247 if let Some(last_segment) = type_path.path.segments.last() {
248 let ident = &last_segment.ident;
249 if ident == "Option" || ident == "Datum" {
250 return true;
251 }
252 if ident == "OptionalArg" {
253 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
255 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
256 return is_nullable_type(inner_ty);
257 }
258 }
259 return false;
260 }
261 }
262 }
263 false
264}
265
266fn is_variadic_type(ty: &syn::Type) -> bool {
268 if let syn::Type::Path(type_path) = ty {
269 if let Some(last_segment) = type_path.path.segments.last() {
270 return last_segment.ident == "Variadic";
271 }
272 }
273 false
274}
275
276fn variadic_element_is_nullable(ty: &syn::Type) -> bool {
278 if let syn::Type::Path(type_path) = ty {
279 if let Some(last_segment) = type_path.path.segments.last() {
280 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
281 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
282 return is_nullable_type(inner_ty);
283 }
284 }
285 }
286 }
287 false
288}
289
290fn non_nullable_position_checks(param_types: &[syn::Type]) -> Vec<TokenStream> {
296 let mut checks = Vec::new();
297 for (i, ty) in param_types.iter().enumerate() {
298 if is_variadic_type(ty) {
299 if !variadic_element_is_nullable(ty) {
300 checks.push(quote! { || input_types.iter().skip(#i).any(|t| t.nullable) });
301 }
302 } else if !is_nullable_type(ty) {
303 checks.push(quote! { || input_types.get(#i).map_or(false, |t| t.nullable) });
304 }
305 }
306 checks
307}
308
309fn camel_case(ident: &Ident) -> Ident {
310 let mut result = String::new();
311 let mut capitalize_next = true;
312 for c in ident.to_string().chars() {
313 if c == '_' {
314 capitalize_next = true;
315 } else if capitalize_next {
316 result.push(c.to_ascii_uppercase());
317 capitalize_next = false;
318 } else {
319 result.push(c);
320 }
321 }
322 Ident::new(&result, ident.span())
323}
324
325fn find_generic_type_params(func: &syn::ItemFn) -> Vec<Ident> {
328 func.sig
329 .generics
330 .params
331 .iter()
332 .filter_map(|p| {
333 if let syn::GenericParam::Type(tp) = p {
334 Some(tp.ident.clone())
335 } else {
336 None
337 }
338 })
339 .collect()
340}
341
/// How a generic type parameter occurs within a type, as computed by
/// `classify_generic_usage`.
#[derive(Debug, Clone)]
enum GenericUsage {
    /// The parameter does not occur.
    Absent,
    /// The parameter occurs directly, possibly behind `Option`/`Result`/`ExcludeNull` or a
    /// reference (e.g. `T` or `Option<T>`).
    Bare,
    /// The parameter occurs as a type argument of another generic type. The payload is that
    /// container type with the parameter replaced by `Datum<'a>`.
    InContainer(syn::TypePath),
}
353
354impl PartialEq for GenericUsage {
355 fn eq(&self, other: &Self) -> bool {
356 match (self, other) {
357 (GenericUsage::Absent, GenericUsage::Absent) => true,
358 (GenericUsage::Bare, GenericUsage::Bare) => true,
359 (GenericUsage::InContainer(a), GenericUsage::InContainer(b)) => {
360 container_idents_match(a, b)
361 }
362 _ => false,
363 }
364 }
365}
366
367impl Eq for GenericUsage {}
368
369fn container_idents_match(a: &syn::TypePath, b: &syn::TypePath) -> bool {
375 let a_idents: Vec<_> = a.path.segments.iter().map(|s| &s.ident).collect();
376 let b_idents: Vec<_> = b.path.segments.iter().map(|s| &s.ident).collect();
377 a_idents == b_idents
378}
379
/// Classifies how the generic type parameter `generic_name` occurs in `ty`.
///
/// - A single-segment path equal to `generic_name` is `Bare`.
/// - For `Option<..>`, `Result<..>`, and `ExcludeNull<..>` whose first argument is a type,
///   only that first type argument is classified.
/// - Any other path whose last segment has a type argument mentioning `generic_name` is
///   `InContainer`, carrying the path with the parameter erased to `Datum<'a>`.
/// - References are looked through, and tuples merge their elements' usages.
/// - Everything else is `Absent`.
fn classify_generic_usage(ty: &syn::Type, generic_name: &Ident) -> GenericUsage {
    match ty {
        syn::Type::Path(type_path) => {
            if type_path.path.is_ident(generic_name) {
                return GenericUsage::Bare;
            }
            if let Some(last) = type_path.path.segments.last() {
                let ident_str = last.ident.to_string();
                // Nullability/fallibility wrappers are transparent for classification.
                if ident_str == "Option" || ident_str == "Result" || ident_str == "ExcludeNull" {
                    if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
                        if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
                            return classify_generic_usage(inner, generic_name);
                        }
                    }
                }
                if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
                    // Does any type argument of this segment mention the generic?
                    let has_generic_arg = args.args.iter().any(|arg| {
                        if let syn::GenericArgument::Type(inner) = arg {
                            type_contains_ident(inner, generic_name)
                        } else {
                            false
                        }
                    });
                    if has_generic_arg {
                        let erased = erase_generic_param(ty, generic_name);
                        if let syn::Type::Path(erased_path) = erased {
                            return GenericUsage::InContainer(erased_path);
                        }
                    }
                    // NOTE(review): `erase_generic_param` returns a path for a path input, so
                    // this loop only runs when no type argument mentions the generic. In that
                    // case every recursive call appears to return `Absent`, which would make
                    // the loop dead code; confirm before removing.
                    for arg in &args.args {
                        if let syn::GenericArgument::Type(inner) = arg {
                            let inner_usage = classify_generic_usage(inner, generic_name);
                            if inner_usage != GenericUsage::Absent {
                                return inner_usage;
                            }
                        }
                    }
                }
            }
            GenericUsage::Absent
        }
        syn::Type::Reference(r) => classify_generic_usage(&r.elem, generic_name),
        syn::Type::Tuple(t) => {
            // Merge element usages. A `Bare` result may be upgraded by a later non-absent
            // usage; any other disagreement collapses the whole tuple to `Bare`.
            // NOTE(review): the merge is order-dependent: (Bare, InContainer) yields
            // InContainer, but (InContainer, Bare) yields Bare. Confirm this is intended.
            let mut best = GenericUsage::Absent;
            for elem in &t.elems {
                let usage = classify_generic_usage(elem, generic_name);
                match (&best, &usage) {
                    (GenericUsage::Absent, _) => best = usage,
                    (GenericUsage::Bare, u) if *u != GenericUsage::Absent => best = usage.clone(),
                    _ => {
                        if usage != GenericUsage::Absent && usage != best {
                            return GenericUsage::Bare;
                        }
                    }
                }
            }
            best
        }
        _ => GenericUsage::Absent,
    }
}
457
458fn type_contains_ident(ty: &syn::Type, ident: &Ident) -> bool {
460 match ty {
461 syn::Type::Path(type_path) => {
462 if type_path.path.is_ident(ident) {
463 return true;
464 }
465 if let Some(last) = type_path.path.segments.last() {
466 if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
467 return args.args.iter().any(|arg| {
468 if let syn::GenericArgument::Type(inner) = arg {
469 type_contains_ident(inner, ident)
470 } else {
471 false
472 }
473 });
474 }
475 }
476 false
477 }
478 syn::Type::Reference(r) => type_contains_ident(&r.elem, ident),
479 syn::Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)),
480 _ => false,
481 }
482}
483
484fn is_option_wrapped(ty: &syn::Type) -> bool {
486 if let syn::Type::Path(type_path) = ty {
487 if let Some(last) = type_path.path.segments.last() {
488 return last.ident == "Option";
489 }
490 }
491 false
492}
493
494fn derive_output_type_for_generics(
506 input_types: &[syn::Type],
507 output_ty: &syn::Type,
508 generic_names: &[Ident],
509 is_unary: bool,
510) -> darling::Result<Option<TokenStream>> {
511 let generic_name = match generic_names
513 .iter()
514 .find(|gn| classify_generic_usage(output_ty, gn) != GenericUsage::Absent)
515 {
516 Some(gn) => gn,
517 None => return Ok(None),
518 };
519 derive_output_type_for_generic(input_types, output_ty, generic_name, is_unary)
520}
521
522fn derive_output_type_for_generic(
528 input_types: &[syn::Type],
529 output_ty: &syn::Type,
530 generic_name: &Ident,
531 is_unary: bool,
532) -> darling::Result<Option<TokenStream>> {
533 let output_usage = classify_generic_usage(output_ty, generic_name);
534 if output_usage == GenericUsage::Absent {
535 return Ok(None);
536 }
537
538 let nullable = is_option_wrapped(output_ty);
539
540 let mut container_input: Option<(usize, GenericUsage)> = None;
543 for (i, ty) in input_types.iter().enumerate() {
544 let usage = classify_generic_usage(ty, generic_name);
545 match &usage {
546 GenericUsage::InContainer(_) => {
547 container_input = Some((i, usage));
548 break;
549 }
550 GenericUsage::Bare => {
551 if container_input.is_none() {
553 container_input = Some((i, usage));
554 }
555 }
556 GenericUsage::Absent => {}
557 }
558 }
559
560 let (pos, source_usage) = container_input.ok_or_else(|| {
561 darling::Error::custom(
562 "generic parameter T appears in the output type but not in any input type",
563 )
564 })?;
565
566 let input_access = if is_unary {
568 quote! { input_type }
569 } else {
570 let pos_lit = syn::Index::from(pos);
571 quote! { input_types[#pos_lit] }
572 };
573
574 let consistency_checks = if !is_unary {
578 let mut checks = Vec::new();
579 for (i, ty) in input_types.iter().enumerate() {
580 if i == pos {
581 continue;
582 }
583 let usage = classify_generic_usage(ty, generic_name);
584 if usage == GenericUsage::Absent {
585 continue;
586 }
587 let primary_elem = element_type_expr(&input_access, &source_usage);
588 let i_lit = syn::Index::from(i);
589 let other_access = quote! { input_types[#i_lit] };
590 let other_elem = element_type_expr(&other_access, &usage);
591 let generic_str = generic_name.to_string();
592 checks.push(quote! {
593 mz_ore::soft_assert_or_log!(
594 #primary_elem.base_eq(#other_elem),
595 "auto-derived sqlfunc output type inference found inconsistent \
596 SQL types for generic {} across inputs: {:?} vs {:?}; \
597 this indicates a bug in polymorphic coercion, builtin \
598 declaration, or sqlfunc inference",
599 #generic_str,
600 #primary_elem,
601 #other_elem,
602 );
603 });
604 }
605 quote! { #(#checks)* }
606 } else {
607 quote! {}
608 };
609
610 let expr = match (&output_usage, &source_usage) {
613 (GenericUsage::Bare, GenericUsage::InContainer(in_container)) => {
615 let in_c = elide_lifetimes(in_container);
616 quote! {
617 {
618 #consistency_checks
619 <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
620 &#input_access.scalar_type
621 ).clone().nullable(#nullable)
622 }
623 }
624 }
625 (GenericUsage::Bare, GenericUsage::Bare) => {
627 quote! {
628 {
629 #consistency_checks
630 #input_access.scalar_type.clone().nullable(#nullable)
631 }
632 }
633 }
634 (GenericUsage::InContainer(out_container), GenericUsage::InContainer(in_container)) => {
637 let out_c = elide_lifetimes(out_container);
638 let in_c = elide_lifetimes(in_container);
639 quote! {
640 {
641 #consistency_checks
642 <#out_c as mz_repr::SqlContainerType>::wrap_element_type(
643 <#in_c as mz_repr::SqlContainerType>::unwrap_element_type(
644 &#input_access.scalar_type
645 ).clone()
646 ).nullable(#nullable)
647 }
648 }
649 }
650 _ => {
652 return Err(darling::Error::custom(format!(
653 "cannot auto-derive output_type_expr: output uses T as {:?} but \
654 the first T-containing input uses T as {:?}",
655 output_usage, source_usage
656 )));
657 }
658 };
659
660 Ok(Some(expr))
661}
662
663fn element_type_expr(input_access: &TokenStream, usage: &GenericUsage) -> TokenStream {
666 match usage {
667 GenericUsage::Bare => {
668 quote! { &#input_access.scalar_type }
669 }
670 GenericUsage::InContainer(container) => {
671 let c = elide_lifetimes(container);
672 quote! {
673 <#c as mz_repr::SqlContainerType>::unwrap_element_type(
674 &#input_access.scalar_type
675 )
676 }
677 }
678 GenericUsage::Absent => unreachable!("element_type_expr called with Absent usage"),
679 }
680}
681
682fn elide_lifetimes(tp: &syn::TypePath) -> syn::TypePath {
689 let mut tp = tp.clone();
690 for segment in &mut tp.path.segments {
691 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
692 for arg in &mut args.args {
693 match arg {
694 syn::GenericArgument::Lifetime(lt) => {
695 *lt = Lifetime::new("'_", lt.span());
696 }
697 syn::GenericArgument::Type(ty) => {
698 elide_lifetimes_in_type(ty);
699 }
700 _ => {}
701 }
702 }
703 }
704 }
705 tp
706}
707
708fn elide_lifetimes_in_type(ty: &mut syn::Type) {
710 match ty {
711 syn::Type::Path(tp) => {
712 *tp = elide_lifetimes(tp);
713 }
714 syn::Type::Reference(r) => {
715 if let Some(lt) = &mut r.lifetime {
716 *lt = Lifetime::new("'_", lt.span());
717 }
718 elide_lifetimes_in_type(&mut r.elem);
719 }
720 syn::Type::Tuple(t) => {
721 for elem in &mut t.elems {
722 elide_lifetimes_in_type(elem);
723 }
724 }
725 _ => {}
726 }
727}
728
729fn erase_generic_param(ty: &syn::Type, generic_name: &Ident) -> syn::Type {
734 match ty {
735 syn::Type::Path(type_path) => {
736 if type_path.path.is_ident(generic_name) {
737 return syn::parse_quote!(Datum<'a>);
738 }
739 let mut type_path = type_path.clone();
740 for segment in &mut type_path.path.segments {
741 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
742 for arg in &mut args.args {
743 if let syn::GenericArgument::Type(inner) = arg {
744 *inner = erase_generic_param(inner, generic_name);
745 }
746 }
747 }
748 }
749 syn::Type::Path(type_path)
750 }
751 syn::Type::Reference(r) => {
752 let elem = Box::new(erase_generic_param(&r.elem, generic_name));
753 syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
754 }
755 syn::Type::Tuple(t) => {
756 let elems = t
757 .elems
758 .iter()
759 .map(|e| erase_generic_param(e, generic_name))
760 .collect();
761 syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
762 }
763 _ => ty.clone(),
764 }
765}
766
767fn erase_all_generic_params(ty: &syn::Type, generic_names: &[Ident]) -> syn::Type {
769 let mut ty = ty.clone();
770 for gn in generic_names {
771 ty = erase_generic_param(&ty, gn);
772 }
773 ty
774}
775
776fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
783 match &arg.sig.inputs[nth] {
784 syn::FnArg::Typed(pat) => {
785 if let syn::Type::Reference(r) = &*pat.ty {
787 if r.lifetime.is_none() {
788 let ty = syn::Type::Reference(syn::TypeReference {
789 lifetime: Some(Lifetime::new("'a", r.span())),
790 ..r.clone()
791 });
792 return Ok(ty);
793 }
794 }
795 Ok((*pat.ty).clone())
796 }
797 syn::FnArg::Receiver(_) => Err(syn::Error::new(
798 arg.sig.inputs[nth].span(),
799 "Unsupported argument type",
800 )),
801 }
802}
803
804fn patch_lifetimes(ty: &syn::Type) -> syn::Type {
807 match ty {
808 syn::Type::Reference(r) => {
809 let elem = Box::new(patch_lifetimes(&r.elem));
810 if r.lifetime.is_none() {
811 syn::Type::Reference(syn::TypeReference {
812 lifetime: Some(Lifetime::new("'a", r.span())),
813 elem,
814 ..r.clone()
815 })
816 } else {
817 syn::Type::Reference(syn::TypeReference { elem, ..r.clone() })
818 }
819 }
820 syn::Type::Tuple(t) => {
821 let elems = t.elems.iter().map(patch_lifetimes).collect();
822 syn::Type::Tuple(syn::TypeTuple { elems, ..t.clone() })
823 }
824 syn::Type::Path(p) => {
825 let mut p = p.clone();
826 for segment in &mut p.path.segments {
827 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
828 for arg in &mut args.args {
829 if let syn::GenericArgument::Type(ty) = arg {
830 *ty = patch_lifetimes(ty);
831 }
832 }
833 }
834 }
835 syn::Type::Path(p)
836 }
837 _ => ty.clone(),
838 }
839}
840
841fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
844 match &arg.sig.output {
845 syn::ReturnType::Type(_, ty) => Ok(&*ty),
846 syn::ReturnType::Default => Err(syn::Error::new(
847 arg.sig.output.span(),
848 "Function needs to return a value",
849 )),
850 }
851}
852
/// Generates the code for a one-argument `sqlfunc`.
///
/// The output contains:
/// - a unit struct named after the function in `CamelCase`;
/// - an implementation of `crate::func::EagerUnaryFunc` whose `call` forwards to the
///   function;
/// - a `Display` impl that writes `sqlname` (or the stringified function name);
/// - the original function itself.
///
/// For generic functions with neither `output_type` nor `output_type_expr`, an output type
/// expression is derived when the return type mentions a generic parameter. In that case
/// `introduces_nulls` defaults to whether the return type is an `Option`.
///
/// # Errors
///
/// Fails if the argument or return type cannot be extracted or output type derivation
/// fails. It also fails if `is_infix_op`, `negate`, `propagates_nulls`, or
/// `is_associative` is given, if `output_type` and `output_type_expr` are combined, or if
/// `output_type_expr` is given without `introduces_nulls`.
fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input_ty_raw = arg_type(func, 0)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // Generic parameters become `Datum<'a>` in the trait's associated types.
    let input_ty = erase_all_generic_params(&input_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // Auto-derive the output type for generic functions unless the caller supplied one.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input_ty_raw],
                output_ty_raw,
                &generic_params,
                true,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    // Reject modifiers that do not apply to unary functions. The order of these checks
    // determines which error is reported first.
    if is_infix_op.is_some() {
        return Err(darling::Error::unknown_field(
            "is_infix_op not supported for unary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if negate.is_some() {
        return Err(darling::Error::unknown_field(
            "negate not supported for unary functions",
        ));
    }
    if propagates_nulls.is_some() {
        return Err(darling::Error::unknown_field(
            "propagates_nulls not supported for unary functions",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for unary functions",
        ));
    }

    let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
        quote! {
            fn preserves_uniqueness(&self) -> bool {
                #preserves_uniqueness
            }
        }
    });

    let inverse_fn = inverse.as_ref().map(|inverse| {
        quote! {
            fn inverse(&self) -> Option<crate::UnaryFunc> {
                #inverse
            }
        }
    });

    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> bool {
                #is_monotone
            }
        }
    });

    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // An explicit `output_type` supplies both the column type and a default
    // `introduces_nulls`; otherwise the column type comes from the trait's `Output` type.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit (or derived) `introduces_nulls` overrides the `output_type` default.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_eliminable_cast_fn = is_eliminable_cast.map(|is_eliminable_cast| {
        quote! {
            fn is_eliminable_cast(&self) -> bool {
                #is_eliminable_cast
            }
        }
    });

    let result = quote! {
        #[derive(
            Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        #[cfg_attr(any(test, feature = "proptest"), derive(proptest_derive::Arbitrary))]
        pub struct #struct_name;

        impl crate::func::EagerUnaryFunc for #struct_name {
            type Input<'a> = #input_ty;
            type Output<'a> = #output_ty;

            fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
                #fn_name(a)
            }

            fn output_sql_type(
                &self,
                input_type: mz_repr::SqlColumnType
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                // Nullable if the output type is, or if a nullable input is propagated.
                output.nullable(nullable || (propagates_nulls && input_type.nullable))
            }

            #could_error_fn
            #introduces_nulls_fn
            #inverse_fn
            #is_monotone_fn
            #preserves_uniqueness_fn
            #is_eliminable_cast_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func
    };
    Ok(result)
}
1047
/// Generates the code for a two-argument `sqlfunc`.
///
/// The output contains:
/// - a unit struct named after the function in `CamelCase`;
/// - an implementation of `crate::func::binary::EagerBinaryFunc` whose `call` forwards
///   both arguments to the function (plus `temp_storage` when `arena` is true);
/// - a `Display` impl that writes `sqlname` (or the stringified function name);
/// - the original function itself.
///
/// For generic functions with neither `output_type` nor `output_type_expr`, an output type
/// expression is derived when the return type mentions a generic parameter. In that case
/// `introduces_nulls` defaults to whether the return type is an `Option`.
///
/// # Errors
///
/// Fails if an argument or the return type cannot be extracted or output type derivation
/// fails. It also fails if `preserves_uniqueness`, `inverse`, `is_associative`, or
/// `is_eliminable_cast` is given, if `output_type` and `output_type_expr` are combined, or
/// if `output_type_expr` is given without `introduces_nulls`.
fn binary_func(
    func: &syn::ItemFn,
    modifiers: Modifiers,
    arena: bool,
) -> darling::Result<TokenStream> {
    let fn_name = &func.sig.ident;
    let struct_name = camel_case(&func.sig.ident);
    let input1_ty_raw = arg_type(func, 0)?;
    let input2_ty_raw = arg_type(func, 1)?;
    let output_ty_raw = output_type(func)?;
    let generic_params = find_generic_type_params(func);
    // Generic parameters become `Datum<'a>` in the trait's associated types.
    let input1_ty = erase_all_generic_params(&input1_ty_raw, &generic_params);
    let input2_ty = erase_all_generic_params(&input2_ty_raw, &generic_params);
    let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);

    let Modifiers {
        is_monotone,
        sqlname,
        preserves_uniqueness,
        inverse,
        is_infix_op,
        output_type,
        mut output_type_expr,
        negate,
        could_error,
        propagates_nulls,
        mut introduces_nulls,
        is_associative,
        is_eliminable_cast,
        test: _,
    } = modifiers;

    // Auto-derive the output type for generic functions unless the caller supplied one.
    if !generic_params.is_empty() {
        if output_type_expr.is_none() && output_type.is_none() {
            if let Some(derived) = derive_output_type_for_generics(
                &[input1_ty_raw, input2_ty_raw],
                output_ty_raw,
                &generic_params,
                false,
            )? {
                output_type_expr = Some(syn::parse2(derived)?);
                if introduces_nulls.is_none() {
                    let nullable = is_option_wrapped(output_ty_raw);
                    introduces_nulls = Some(syn::parse_quote!(#nullable));
                }
            }
        }
    }

    // Reject modifiers that do not apply to binary functions. The order of these checks
    // determines which error is reported first.
    if preserves_uniqueness.is_some() {
        return Err(darling::Error::unknown_field(
            "preserves_uniqueness not supported for binary functions",
        ));
    }
    if inverse.is_some() {
        return Err(darling::Error::unknown_field(
            "inverse not supported for binary functions",
        ));
    }
    if output_type.is_some() && output_type_expr.is_some() {
        return Err(darling::Error::unknown_field(
            "output_type and output_type_expr cannot be used together",
        ));
    }
    if output_type_expr.is_some() && introduces_nulls.is_none() {
        return Err(darling::Error::unknown_field(
            "output_type_expr requires introduces_nulls",
        ));
    }
    if is_associative.is_some() {
        return Err(darling::Error::unknown_field(
            "is_associative not supported for binary functions",
        ));
    }
    if is_eliminable_cast.is_some() {
        return Err(darling::Error::unknown_field(
            "is_eliminable_cast not supported for binary functions",
        ));
    }

    let negate_fn = negate.map(|negate| {
        quote! {
            fn negate(&self) -> Option<crate::BinaryFunc> {
                #negate
            }
        }
    });

    let is_monotone_fn = is_monotone.map(|is_monotone| {
        quote! {
            fn is_monotone(&self) -> (bool, bool) {
                #is_monotone
            }
        }
    });

    let name = sqlname
        .as_ref()
        .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });

    // An explicit `output_type` supplies both the column type and a default
    // `introduces_nulls`; otherwise the column type comes from the trait's `Output` type.
    let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
        let introduces_nulls_fn = quote! {
            fn introduces_nulls(&self) -> bool {
                <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
            }
        };
        let output_type = quote! { <#output_type>::as_column_type() };
        (output_type, Some(introduces_nulls_fn))
    } else {
        (quote! { Self::Output::as_column_type() }, None)
    };

    if let Some(output_type_expr) = output_type_expr {
        output_type = quote! { #output_type_expr };
    }

    // An explicit (or derived) `introduces_nulls` overrides the `output_type` default.
    if let Some(introduces_nulls) = introduces_nulls {
        introduces_nulls_fn = Some(quote! {
            fn introduces_nulls(&self) -> bool {
                #introduces_nulls
            }
        });
    }

    // Extra argument appended to the forwarded call when the function takes a `RowArena`.
    let arena = if arena {
        quote! { , temp_storage }
    } else {
        quote! {}
    };

    let could_error_fn = could_error.map(|could_error| {
        quote! {
            fn could_error(&self) -> bool {
                #could_error
            }
        }
    });

    let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
        quote! {
            fn is_infix_op(&self) -> bool {
                #is_infix_op
            }
        }
    });

    let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
        quote! {
            fn propagates_nulls(&self) -> bool {
                #propagates_nulls
            }
        }
    });

    // `|| ...` fragments that flag a nullable input feeding a parameter whose Rust type
    // cannot hold NULL (checked against the generic-erased parameter types).
    let binary_non_nullable_checks =
        non_nullable_position_checks(&[input1_ty.clone(), input2_ty.clone()]);

    let result = quote! {
        #[derive(
            Ord, PartialOrd, Clone,
            Debug, Eq, PartialEq, serde::Serialize,
            serde::Deserialize, Hash, mz_lowertest::MzReflect,
        )]
        #[cfg_attr(any(test, feature = "proptest"), derive(proptest_derive::Arbitrary))]
        pub struct #struct_name;

        impl crate::func::binary::EagerBinaryFunc for #struct_name {
            type Input<'a> = (#input1_ty, #input2_ty);
            type Output<'a> = #output_ty;

            fn call<'a>(
                &self,
                (a, b): Self::Input<'a>,
                temp_storage: &'a mz_repr::RowArena
            ) -> Self::Output<'a> {
                #fn_name(a, b #arena)
            }

            fn output_sql_type(
                &self,
                input_types: &[mz_repr::SqlColumnType],
            ) -> mz_repr::SqlColumnType {
                use mz_repr::AsColumnType;
                let output = #output_type;
                let propagates_nulls =
                    crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
                let nullable = output.nullable;
                // Seeded with `false`; each generated fragment appends `|| <check>`.
                let non_nullable_input_is_nullable =
                    false #(#binary_non_nullable_checks)*;
                let inputs_nullable = input_types.iter().any(|it| it.nullable);
                let is_null = nullable
                    || non_nullable_input_is_nullable
                    || (propagates_nulls && inputs_nullable);
                output.nullable(is_null)
            }

            #could_error_fn
            #introduces_nulls_fn
            #is_infix_op_fn
            #is_monotone_fn
            #negate_fn
            #propagates_nulls_fn
        }

        impl std::fmt::Display for #struct_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(#name)
            }
        }

        #func

    };
    Ok(result)
}
1275
1276fn variadic_func(
1282 func: &syn::ItemFn,
1283 modifiers: Modifiers,
1284 struct_ty: Option<syn::Path>,
1285 arena: bool,
1286 has_self: bool,
1287) -> darling::Result<TokenStream> {
1288 let fn_name = &func.sig.ident;
1289 let output_ty_raw = output_type(func)?;
1290 let generic_params = find_generic_type_params(func);
1291 let output_ty = erase_all_generic_params(output_ty_raw, &generic_params);
1292 let struct_name = struct_ty
1293 .as_ref()
1294 .and_then(|ty| ty.segments.last())
1295 .map_or_else(|| camel_case(fn_name), |seg| seg.ident.clone());
1296
1297 let Modifiers {
1298 is_monotone,
1299 sqlname,
1300 preserves_uniqueness,
1301 inverse,
1302 is_infix_op,
1303 output_type,
1304 mut output_type_expr,
1305 negate,
1306 could_error,
1307 propagates_nulls,
1308 mut introduces_nulls,
1309 is_associative,
1310 is_eliminable_cast,
1311 test: _,
1312 } = modifiers;
1313
1314 if preserves_uniqueness.is_some() {
1316 return Err(darling::Error::unknown_field(
1317 "preserves_uniqueness not supported for variadic functions",
1318 ));
1319 }
1320 if inverse.is_some() {
1321 return Err(darling::Error::unknown_field(
1322 "inverse not supported for variadic functions",
1323 ));
1324 }
1325 if negate.is_some() {
1326 return Err(darling::Error::unknown_field(
1327 "negate not supported for variadic functions",
1328 ));
1329 }
1330 if is_eliminable_cast.is_some() {
1331 return Err(darling::Error::unknown_field(
1332 "is_eliminable_cast not supported for variadic functions",
1333 ));
1334 }
1335 if output_type.is_some() && output_type_expr.is_some() {
1336 return Err(darling::Error::unknown_field(
1337 "output_type and output_type_expr cannot be used together",
1338 ));
1339 }
1340 if output_type_expr.is_some() && introduces_nulls.is_none() {
1341 return Err(darling::Error::unknown_field(
1342 "output_type_expr requires introduces_nulls",
1343 ));
1344 }
1345
1346 let start = if has_self { 1 } else { 0 };
1348 let end = if arena {
1349 func.sig.inputs.len() - 1
1350 } else {
1351 func.sig.inputs.len()
1352 };
1353 let input_params: Vec<&syn::FnArg> = func
1354 .sig
1355 .inputs
1356 .iter()
1357 .skip(start)
1358 .take(end - start)
1359 .collect();
1360
1361 if input_params.is_empty() {
1362 return Err(darling::Error::custom(
1363 "variadic function must have at least one input parameter",
1364 ));
1365 }
1366
1367 let mut param_names = Vec::new();
1369 let mut param_types = Vec::new();
1370 for param in &input_params {
1371 match param {
1372 syn::FnArg::Typed(pat) => {
1373 if let syn::Pat::Ident(ident) = &*pat.pat {
1374 param_names.push(ident.ident.clone());
1375 } else {
1376 return Err(
1377 darling::Error::custom("unsupported parameter pattern").with_span(&pat.pat)
1378 );
1379 }
1380 param_types.push(patch_lifetimes(&pat.ty));
1381 }
1382 syn::FnArg::Receiver(_) => {
1383 return Err(darling::Error::custom("unexpected self parameter"));
1384 }
1385 }
1386 }
1387
1388 if !generic_params.is_empty() {
1391 if output_type_expr.is_none() && output_type.is_none() {
1392 if let Some(derived) = derive_output_type_for_generics(
1393 ¶m_types,
1394 output_ty_raw,
1395 &generic_params,
1396 false,
1397 )? {
1398 output_type_expr = Some(syn::parse2(derived)?);
1399 if introduces_nulls.is_none() {
1400 let nullable = is_option_wrapped(output_ty_raw);
1401 introduces_nulls = Some(syn::parse_quote!(#nullable));
1402 }
1403 }
1404 }
1405 }
1406
1407 for ty in &mut param_types {
1409 *ty = erase_all_generic_params(ty, &generic_params);
1410 }
1411
1412 let input_type: syn::Type = if param_types.len() == 1 {
1414 param_types[0].clone()
1415 } else {
1416 syn::parse_quote! { (#(#param_types),*) }
1417 };
1418
1419 let destructure = if param_names.len() == 1 {
1421 let name = ¶m_names[0];
1422 quote! { #name }
1423 } else {
1424 quote! { (#(#param_names),*) }
1425 };
1426
1427 let arena_arg = if arena {
1428 quote! { , temp_storage }
1429 } else {
1430 quote! {}
1431 };
1432
1433 let call_expr = if has_self {
1434 quote! { self.#fn_name(#(#param_names),* #arena_arg) }
1435 } else {
1436 quote! { #fn_name(#(#param_names),* #arena_arg) }
1437 };
1438
1439 let name = sqlname
1441 .as_ref()
1442 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
1443
1444 let (mut output_type_code, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
1445 let introduces_nulls_fn = quote! {
1446 fn introduces_nulls(&self) -> bool {
1447 <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
1448 }
1449 };
1450 let output_type_code = quote! { <#output_type>::as_column_type() };
1451 (output_type_code, Some(introduces_nulls_fn))
1452 } else {
1453 (quote! { Self::Output::as_column_type() }, None)
1454 };
1455
1456 if let Some(output_type_expr) = output_type_expr {
1457 output_type_code = quote! { #output_type_expr };
1458 }
1459
1460 if let Some(introduces_nulls) = introduces_nulls {
1461 introduces_nulls_fn = Some(quote! {
1462 fn introduces_nulls(&self) -> bool {
1463 #introduces_nulls
1464 }
1465 });
1466 }
1467
1468 let could_error_fn = could_error.map(|could_error| {
1469 quote! {
1470 fn could_error(&self) -> bool {
1471 #could_error
1472 }
1473 }
1474 });
1475
1476 let is_monotone_fn = is_monotone.map(|is_monotone| {
1477 quote! {
1478 fn is_monotone(&self) -> bool {
1479 #is_monotone
1480 }
1481 }
1482 });
1483
1484 let is_associative_fn = is_associative.map(|is_associative| {
1485 quote! {
1486 fn is_associative(&self) -> bool {
1487 #is_associative
1488 }
1489 }
1490 });
1491
1492 let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
1493 quote! {
1494 fn is_infix_op(&self) -> bool {
1495 #is_infix_op
1496 }
1497 }
1498 });
1499
1500 let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
1501 quote! {
1502 fn propagates_nulls(&self) -> bool {
1503 #propagates_nulls
1504 }
1505 }
1506 });
1507
1508 let non_nullable_checks = non_nullable_position_checks(¶m_types);
1511
1512 let trait_impl = quote! {
1513 impl crate::func::variadic::EagerVariadicFunc for #struct_name {
1514 type Input<'a> = #input_type;
1515 type Output<'a> = #output_ty;
1516
1517 fn call<'a>(
1518 &self,
1519 #destructure: Self::Input<'a>,
1520 temp_storage: &'a mz_repr::RowArena,
1521 ) -> Self::Output<'a> {
1522 #call_expr
1523 }
1524
1525 fn output_type(
1526 &self,
1527 input_types: &[mz_repr::SqlColumnType],
1528 ) -> mz_repr::SqlColumnType {
1529 use mz_repr::AsColumnType;
1530 let output = #output_type_code;
1531 let propagates_nulls =
1532 crate::func::variadic::EagerVariadicFunc::propagates_nulls(self);
1533 let nullable = output.nullable;
1534 let non_nullable_input_is_nullable =
1541 false #(#non_nullable_checks)*;
1542 let inputs_nullable = input_types.iter().any(|it| it.nullable);
1543 output.nullable(
1544 nullable
1545 || non_nullable_input_is_nullable
1546 || (propagates_nulls && inputs_nullable)
1547 )
1548 }
1549
1550 #could_error_fn
1551 #introduces_nulls_fn
1552 #is_infix_op_fn
1553 #is_monotone_fn
1554 #is_associative_fn
1555 #propagates_nulls_fn
1556 }
1557 };
1558
1559 let display_impl = quote! {
1560 impl std::fmt::Display for #struct_name {
1561 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1562 f.write_str(#name)
1563 }
1564 }
1565 };
1566
1567 let result = if has_self {
1568 quote! {
1570 impl #struct_name {
1571 #func
1572 }
1573 #trait_impl
1574 #display_impl
1575 }
1576 } else {
1577 quote! {
1579 #[derive(
1580 Ord, PartialOrd, Clone,
1581 Debug, Eq, PartialEq, serde::Serialize,
1582 serde::Deserialize, Hash, mz_lowertest::MzReflect,
1583 )]
1584 #[cfg_attr(any(test, feature = "proptest"), derive(proptest_derive::Arbitrary))]
1585 pub struct #struct_name;
1586
1587 #trait_impl
1588 #display_impl
1589
1590 #func
1591 }
1592 };
1593
1594 Ok(result)
1595}