1use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Modifiers parsed from the attribute arguments passed to [`sqlfunc`].
///
/// Every modifier is optional. Expression-valued modifiers are pasted verbatim
/// as the body of the generated trait method of the same name.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// Body of the generated `is_monotone` method. The generated method returns
    /// `bool` for unary functions and `(bool, bool)` for binary functions.
    is_monotone: Option<Expr>,
    /// Value written by the generated `Display` impl. When absent, the
    /// stringified function name is used instead.
    sqlname: Option<SqlName>,
    /// Body of the generated `preserves_uniqueness` method. Rejected for binary
    /// functions.
    preserves_uniqueness: Option<Expr>,
    /// Body of the generated `inverse` method, which returns
    /// `Option<crate::UnaryFunc>`. Rejected for binary functions.
    inverse: Option<Expr>,
    /// Body of the generated `negate` method, which returns
    /// `Option<crate::BinaryFunc>`. Rejected for unary functions.
    negate: Option<Expr>,
    /// Body of the generated `is_infix_op` method. Rejected for unary functions.
    is_infix_op: Option<Expr>,
    /// Type whose `as_column_type()` supplies the output column type and whose
    /// `nullable()` supplies the default `introduces_nulls` body. Cannot be
    /// combined with `output_type_expr`.
    output_type: Option<syn::Path>,
    /// Expression used directly as the output column type. Requires
    /// `introduces_nulls` to be set as well.
    output_type_expr: Option<Expr>,
    /// Body of the generated `could_error` method.
    could_error: Option<Expr>,
    /// Body of the generated `propagates_nulls` method. Rejected for unary
    /// functions.
    propagates_nulls: Option<Expr>,
    /// Body of the generated `introduces_nulls` method. Takes precedence over
    /// the body derived from `output_type`.
    introduces_nulls: Option<Expr>,
    /// Whether a snapshot test should be generated. Treated as `false` when
    /// absent.
    test: Option<bool>,
}
47
/// The value of the `sqlname` modifier: either a literal or a macro invocation
/// that is emitted unchanged into the generated `Display` impl.
#[derive(Debug)]
enum SqlName {
    /// A literal, e.g. a string literal.
    Literal(syn::Lit),
    /// A macro invocation expression, emitted as written.
    Macro(syn::ExprMacro),
}
57
58impl quote::ToTokens for SqlName {
59 fn to_tokens(&self, tokens: &mut TokenStream) {
60 let name = match self {
61 SqlName::Literal(lit) => quote! { #lit },
62 SqlName::Macro(mac) => quote! { #mac },
63 };
64 tokens.extend(name);
65 }
66}
67
68impl darling::FromMeta for SqlName {
69 fn from_value(value: &Lit) -> darling::Result<Self> {
70 Ok(Self::Literal(value.clone()))
71 }
72 fn from_expr(expr: &Expr) -> darling::Result<Self> {
73 match expr {
74 Expr::Lit(lit) => Self::from_value(&lit.lit),
75 Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
76 Expr::Group(mac) => Self::from_expr(&mac.expr),
79 _ => Err(darling::Error::unexpected_expr_type(expr)),
80 }
81 }
82}
83
84pub fn sqlfunc(
90 attr: TokenStream,
91 item: TokenStream,
92 include_test: bool,
93) -> darling::Result<TokenStream> {
94 let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
95 let modifiers = Modifiers::from_list(&attr_args).unwrap();
96 let generate_tests = modifiers.test.unwrap_or(false);
97 let func = syn::parse2::<syn::ItemFn>(item.clone())?;
98
99 let tokens = match determine_parameters_arena(&func) {
100 (1, false) => unary_func(&func, modifiers),
101 (1, true) => Err(darling::Error::custom(
102 "Unary functions do not yet support RowArena.",
103 )),
104 (2, arena) => binary_func(&func, modifiers, arena),
105 (other, _) => Err(darling::Error::custom(format!(
106 "Unsupported function: {} parameters",
107 other
108 ))),
109 }?;
110
111 let test = (generate_tests && include_test).then(|| generate_test(attr, item, &func.sig.ident));
112
113 Ok(quote! {
114 #tokens
115 #test
116 })
117}
118
/// Builds a `test_<name>` test function that feeds the stringified attribute
/// and item to `mz_expr_derive_impl::test_sqlfunc_str` and snapshots the
/// result under the function's name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let snapshot_name = name.to_string();
    let test_ident = Ident::new(&format!("test_{}", snapshot_name), name.span());
    let attr_src = attr.to_string();
    let item_src = item.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)]
        #[mz_ore::test]
        fn #test_ident() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
136
137#[cfg(not(any(feature = "test", test)))]
138fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
139 quote! {}
140}
141
142fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
145 let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
146 if let syn::FnArg::Typed(pat) = last {
147 if let syn::Type::Reference(reference) = &*pat.ty {
148 if let syn::Type::Path(path) = &*reference.elem {
149 return path.path.is_ident("RowArena");
150 }
151 }
152 }
153 false
154 });
155 let parameters = func.sig.inputs.len();
156 if last_is_arena {
157 (parameters - 1, true)
158 } else {
159 (parameters, false)
160 }
161}
162
163fn camel_case(ident: &Ident) -> Ident {
165 let mut result = String::new();
166 let mut capitalize_next = true;
167 for c in ident.to_string().chars() {
168 if c == '_' {
169 capitalize_next = true;
170 } else if capitalize_next {
171 result.push(c.to_ascii_uppercase());
172 capitalize_next = false;
173 } else {
174 result.push(c);
175 }
176 }
177 Ident::new(&result, ident.span())
178}
179
180fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
187 match &arg.sig.inputs[nth] {
188 syn::FnArg::Typed(pat) => {
189 if let syn::Type::Reference(r) = &*pat.ty {
191 if r.lifetime.is_none() {
192 let ty = syn::Type::Reference(syn::TypeReference {
193 lifetime: Some(Lifetime::new("'a", r.span())),
194 ..r.clone()
195 });
196 return Ok(ty);
197 }
198 }
199 Ok((*pat.ty).clone())
200 }
201 _ => Err(syn::Error::new(
202 arg.sig.inputs[nth].span(),
203 "Unsupported argument type",
204 )),
205 }
206}
207
208fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
211 match &arg.sig.output {
212 syn::ReturnType::Type(_, ty) => Ok(&*ty),
213 syn::ReturnType::Default => Err(syn::Error::new(
214 arg.sig.output.span(),
215 "Function needs to return a value",
216 )),
217 }
218}
219
220fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
222 let fn_name = &func.sig.ident;
223 let struct_name = camel_case(&func.sig.ident);
224 let input_ty = arg_type(func, 0)?;
225 let output_ty = output_type(func)?;
226 let Modifiers {
227 is_monotone,
228 sqlname,
229 preserves_uniqueness,
230 inverse,
231 is_infix_op,
232 output_type,
233 output_type_expr,
234 negate,
235 could_error,
236 propagates_nulls,
237 introduces_nulls,
238 test: _,
239 } = modifiers;
240
241 if is_infix_op.is_some() {
242 return Err(darling::Error::unknown_field(
243 "is_infix_op not supported for unary functions",
244 ));
245 }
246 if output_type.is_some() && output_type_expr.is_some() {
247 return Err(darling::Error::unknown_field(
248 "output_type and output_type_expr cannot be used together",
249 ));
250 }
251 if output_type_expr.is_some() && introduces_nulls.is_none() {
252 return Err(darling::Error::unknown_field(
253 "output_type_expr requires introduces_nulls",
254 ));
255 }
256 if negate.is_some() {
257 return Err(darling::Error::unknown_field(
258 "negate not supported for unary functions",
259 ));
260 }
261 if propagates_nulls.is_some() {
262 return Err(darling::Error::unknown_field(
263 "propagates_nulls not supported for unary functions",
264 ));
265 }
266
267 let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
268 quote! {
269 fn preserves_uniqueness(&self) -> bool {
270 #preserves_uniqueness
271 }
272 }
273 });
274
275 let inverse_fn = inverse.as_ref().map(|inverse| {
276 quote! {
277 fn inverse(&self) -> Option<crate::UnaryFunc> {
278 #inverse
279 }
280 }
281 });
282
283 let is_monotone_fn = is_monotone.map(|is_monotone| {
284 quote! {
285 fn is_monotone(&self) -> bool {
286 #is_monotone
287 }
288 }
289 });
290
291 let name = sqlname
292 .as_ref()
293 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
294
295 let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
296 let introduces_nulls_fn = quote! {
297 fn introduces_nulls(&self) -> bool {
298 <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
299 }
300 };
301 let output_type = quote! { <#output_type>::as_column_type() };
302 (output_type, Some(introduces_nulls_fn))
303 } else {
304 (quote! { Self::Output::as_column_type() }, None)
305 };
306
307 if let Some(output_type_expr) = output_type_expr {
308 output_type = quote! { #output_type_expr };
309 }
310
311 if let Some(introduces_nulls) = introduces_nulls {
312 introduces_nulls_fn = Some(quote! {
313 fn introduces_nulls(&self) -> bool {
314 #introduces_nulls
315 }
316 });
317 }
318
319 let could_error_fn = could_error.map(|could_error| {
320 quote! {
321 fn could_error(&self) -> bool {
322 #could_error
323 }
324 }
325 });
326
327 let result = quote! {
328 #[derive(
329 proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
330 Debug, Eq, PartialEq, serde::Serialize,
331 serde::Deserialize, Hash, mz_lowertest::MzReflect,
332 )]
333 pub struct #struct_name;
334
335 impl crate::func::EagerUnaryFunc for #struct_name {
336 type Input<'a> = #input_ty;
337 type Output<'a> = #output_ty;
338
339 fn call<'a>(&self, a: Self::Input<'a>) -> Self::Output<'a> {
340 #fn_name(a)
341 }
342
343 fn output_type(&self, input_type: mz_repr::SqlColumnType) -> mz_repr::SqlColumnType {
344 use mz_repr::AsColumnType;
345 let output = #output_type;
346 let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
347 let nullable = output.nullable;
348 output.nullable(nullable || (propagates_nulls && input_type.nullable))
351 }
352
353 #could_error_fn
354 #introduces_nulls_fn
355 #inverse_fn
356 #is_monotone_fn
357 #preserves_uniqueness_fn
358 }
359
360 impl std::fmt::Display for #struct_name {
361 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
362 f.write_str(#name)
363 }
364 }
365
366 #func
367 };
368 Ok(result)
369}
370
371fn binary_func(
373 func: &syn::ItemFn,
374 modifiers: Modifiers,
375 arena: bool,
376) -> darling::Result<TokenStream> {
377 let fn_name = &func.sig.ident;
378 let struct_name = camel_case(&func.sig.ident);
379 let input1_ty = arg_type(func, 0)?;
380 let input2_ty = arg_type(func, 1)?;
381 let output_ty = output_type(func)?;
382
383 let Modifiers {
384 is_monotone,
385 sqlname,
386 preserves_uniqueness,
387 inverse,
388 is_infix_op,
389 output_type,
390 output_type_expr,
391 negate,
392 could_error,
393 propagates_nulls,
394 introduces_nulls,
395 test: _,
396 } = modifiers;
397
398 if preserves_uniqueness.is_some() {
399 return Err(darling::Error::unknown_field(
400 "preserves_uniqueness not supported for binary functions",
401 ));
402 }
403 if inverse.is_some() {
404 return Err(darling::Error::unknown_field(
405 "inverse not supported for binary functions",
406 ));
407 }
408 if output_type.is_some() && output_type_expr.is_some() {
409 return Err(darling::Error::unknown_field(
410 "output_type and output_type_expr cannot be used together",
411 ));
412 }
413 if output_type_expr.is_some() && introduces_nulls.is_none() {
414 return Err(darling::Error::unknown_field(
415 "output_type_expr requires introduces_nulls",
416 ));
417 }
418
419 let negate_fn = negate.map(|negate| {
420 quote! {
421 fn negate(&self) -> Option<crate::BinaryFunc> {
422 #negate
423 }
424 }
425 });
426
427 let is_monotone_fn = is_monotone.map(|is_monotone| {
428 quote! {
429 fn is_monotone(&self) -> (bool, bool) {
430 #is_monotone
431 }
432 }
433 });
434
435 let name = sqlname
436 .as_ref()
437 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
438
439 let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
440 let introduces_nulls_fn = quote! {
441 fn introduces_nulls(&self) -> bool {
442 <#output_type as ::mz_repr::OutputDatumType<'_, ()>>::nullable()
443 }
444 };
445 let output_type = quote! { <#output_type>::as_column_type() };
446 (output_type, Some(introduces_nulls_fn))
447 } else {
448 (quote! { Self::Output::as_column_type() }, None)
449 };
450
451 if let Some(output_type_expr) = output_type_expr {
452 output_type = quote! { #output_type_expr };
453 }
454
455 if let Some(introduces_nulls) = introduces_nulls {
456 introduces_nulls_fn = Some(quote! {
457 fn introduces_nulls(&self) -> bool {
458 #introduces_nulls
459 }
460 });
461 }
462
463 let arena = if arena {
464 quote! { , temp_storage }
465 } else {
466 quote! {}
467 };
468
469 let could_error_fn = could_error.map(|could_error| {
470 quote! {
471 fn could_error(&self) -> bool {
472 #could_error
473 }
474 }
475 });
476
477 let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
478 quote! {
479 fn is_infix_op(&self) -> bool {
480 #is_infix_op
481 }
482 }
483 });
484
485 let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
486 quote! {
487 fn propagates_nulls(&self) -> bool {
488 #propagates_nulls
489 }
490 }
491 });
492
493 let result = quote! {
494 #[derive(
495 proptest_derive::Arbitrary, Ord, PartialOrd, Clone,
496 Debug, Eq, PartialEq, serde::Serialize,
497 serde::Deserialize, Hash, mz_lowertest::MzReflect,
498 )]
499 pub struct #struct_name;
500
501 impl crate::func::binary::EagerBinaryFunc for #struct_name {
502 type Input<'a> = (#input1_ty, #input2_ty);
503 type Output<'a> = #output_ty;
504
505 fn call<'a>(
506 &self,
507 (a, b): Self::Input<'a>,
508 temp_storage: &'a mz_repr::RowArena
509 ) -> Self::Output<'a> {
510 #fn_name(a, b #arena)
511 }
512
513 fn output_type(
514 &self,
515 input_types: &[mz_repr::SqlColumnType],
516 ) -> mz_repr::SqlColumnType {
517 use mz_repr::AsColumnType;
518 let output = #output_type;
519 let propagates_nulls =
520 crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
521 let nullable = output.nullable;
522 let inputs_nullable = input_types.iter().any(|it| it.nullable);
526 let is_null = nullable || (propagates_nulls && inputs_nullable);
527 output.nullable(is_null)
528 }
529
530 #could_error_fn
531 #introduces_nulls_fn
532 #is_infix_op_fn
533 #is_monotone_fn
534 #negate_fn
535 #propagates_nulls_fn
536 }
537
538 impl std::fmt::Display for #struct_name {
539 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
540 f.write_str(#name)
541 }
542 }
543
544 #func
545
546 };
547 Ok(result)
548}