1use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime, Lit};
15
/// Optional key-value modifiers accepted by the `#[sqlfunc(...)]` attribute.
///
/// Each field holds the user-supplied expression (or path) for one modifier.
/// `unary_func` and `binary_func` turn present fields into method overrides on
/// the generated trait impl, and reject modifiers that do not apply to their
/// arity.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// Body of the generated `is_monotone`. Unary functions return `bool`;
    /// binary functions return `(bool, bool)`.
    is_monotone: Option<Expr>,
    /// Name written by the generated `Display` impl. Defaults to
    /// `stringify!(<function name>)` when absent.
    sqlname: Option<SqlName>,
    /// Body of `preserves_uniqueness(&self) -> bool`. Unary functions only.
    preserves_uniqueness: Option<Expr>,
    /// Body of `inverse(&self) -> Option<crate::UnaryFunc>`. Unary functions only.
    inverse: Option<Expr>,
    /// Body of `negate(&self) -> Option<crate::BinaryFunc>`. Binary functions only.
    negate: Option<Expr>,
    /// Body of `is_infix_op(&self) -> bool`. Binary functions only.
    is_infix_op: Option<Expr>,
    /// Type whose `as_column_type()` gives the output column type. Its
    /// `::mz_repr::DatumType<'_, ()>::nullable()` also becomes the generated
    /// `introduces_nulls`, unless `introduces_nulls` is set explicitly. When
    /// absent, the function's `Output` type is used.
    output_type: Option<syn::Path>,
    /// Expression producing the output column type directly. Binary functions
    /// only; cannot be combined with `output_type` and requires `introduces_nulls`.
    output_type_expr: Option<Expr>,
    /// Body of `could_error(&self) -> bool`.
    could_error: Option<Expr>,
    /// Body of `propagates_nulls(&self) -> bool`. Binary functions only.
    propagates_nulls: Option<Expr>,
    /// Body of `introduces_nulls(&self) -> bool`. Takes precedence over the
    /// implementation derived from `output_type`.
    introduces_nulls: Option<Expr>,
}
45
/// Value of the `sqlname` modifier.
///
/// It is emitted verbatim as the argument of `f.write_str(...)` in the
/// generated `Display` impl, so it must evaluate to a `&str`.
#[derive(Debug)]
enum SqlName {
    /// A literal, e.g. `sqlname = "abs"`.
    Literal(syn::Lit),
    /// A macro invocation expression, emitted as written.
    Macro(syn::ExprMacro),
}
55
56impl quote::ToTokens for SqlName {
57 fn to_tokens(&self, tokens: &mut TokenStream) {
58 let name = match self {
59 SqlName::Literal(lit) => quote! { #lit },
60 SqlName::Macro(mac) => quote! { #mac },
61 };
62 tokens.extend(name);
63 }
64}
65
66impl darling::FromMeta for SqlName {
67 fn from_value(value: &Lit) -> darling::Result<Self> {
68 Ok(Self::Literal(value.clone()))
69 }
70 fn from_expr(expr: &Expr) -> darling::Result<Self> {
71 match expr {
72 Expr::Lit(lit) => Self::from_value(&lit.lit),
73 Expr::Macro(mac) => Ok(Self::Macro(mac.clone())),
74 Expr::Group(mac) => Self::from_expr(&mac.expr),
77 _ => Err(darling::Error::unexpected_expr_type(expr)),
78 }
79 }
80}
81
82pub fn sqlfunc(
88 attr: TokenStream,
89 item: TokenStream,
90 include_test: bool,
91) -> darling::Result<TokenStream> {
92 let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
93 let modifiers = Modifiers::from_list(&attr_args).unwrap();
94 let func = syn::parse2::<syn::ItemFn>(item.clone())?;
95
96 let tokens = match determine_parameters_arena(&func) {
97 (1, false) => unary_func(&func, modifiers),
98 (1, true) => Err(darling::Error::custom(
99 "Unary functions do not yet support RowArena.",
100 )),
101 (2, arena) => binary_func(&func, modifiers, arena),
102 (other, _) => Err(darling::Error::custom(format!(
103 "Unsupported function: {} parameters",
104 other
105 ))),
106 }?;
107
108 let test = include_test.then(|| generate_test(attr, item, &func.sig.ident));
109
110 Ok(quote! {
111 #tokens
112 #test
113 })
114}
115
/// Emits a `#[cfg(test)]` snapshot test for one `#[sqlfunc]` expansion.
///
/// The attribute and item tokens are embedded as strings and re-expanded at
/// test time by `mz_expr_derive_impl::test_sqlfunc_str`. The result is compared
/// with an `insta` snapshot keyed by the function's name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let snapshot_name = name.to_string();
    let test_fn = Ident::new(&format!("test_{}", snapshot_name), name.span());
    let (attr_src, item_src) = (attr.to_string(), item.to_string());

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] #[mz_ore::test]
        fn #test_fn() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr_src, #item_src);
            insta::assert_snapshot!(#snapshot_name, output, &input);
        }
    }
}
133
134#[cfg(not(any(feature = "test", test)))]
135fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
136 quote! {}
137}
138
139fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
142 let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
143 if let syn::FnArg::Typed(pat) = last {
144 if let syn::Type::Reference(reference) = &*pat.ty {
145 if let syn::Type::Path(path) = &*reference.elem {
146 return path.path.is_ident("RowArena");
147 }
148 }
149 }
150 false
151 });
152 let parameters = func.sig.inputs.len();
153 if last_is_arena {
154 (parameters - 1, true)
155 } else {
156 (parameters, false)
157 }
158}
159
160fn camel_case(ident: &Ident) -> Ident {
162 let mut result = String::new();
163 let mut capitalize_next = true;
164 for c in ident.to_string().chars() {
165 if c == '_' {
166 capitalize_next = true;
167 } else if capitalize_next {
168 result.push(c.to_ascii_uppercase());
169 capitalize_next = false;
170 } else {
171 result.push(c);
172 }
173 }
174 Ident::new(&result, ident.span())
175}
176
177fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
184 match &arg.sig.inputs[nth] {
185 syn::FnArg::Typed(pat) => {
186 if let syn::Type::Reference(r) = &*pat.ty {
188 if r.lifetime.is_none() {
189 let ty = syn::Type::Reference(syn::TypeReference {
190 lifetime: Some(Lifetime::new("'a", r.span())),
191 ..r.clone()
192 });
193 return Ok(ty);
194 }
195 }
196 Ok((*pat.ty).clone())
197 }
198 _ => Err(syn::Error::new(
199 arg.sig.inputs[nth].span(),
200 "Unsupported argument type",
201 )),
202 }
203}
204
205fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
208 match &arg.sig.output {
209 syn::ReturnType::Type(_, ty) => Ok(&*ty),
210 syn::ReturnType::Default => Err(syn::Error::new(
211 arg.sig.output.span(),
212 "Function needs to return a value",
213 )),
214 }
215}
216
217fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
219 let fn_name = &func.sig.ident;
220 let struct_name = camel_case(&func.sig.ident);
221 let input_ty = arg_type(func, 0)?;
222 let output_ty = output_type(func)?;
223 let Modifiers {
224 is_monotone,
225 sqlname,
226 preserves_uniqueness,
227 inverse,
228 is_infix_op,
229 output_type,
230 output_type_expr,
231 negate,
232 could_error,
233 propagates_nulls,
234 introduces_nulls,
235 } = modifiers;
236
237 if is_infix_op.is_some() {
238 return Err(darling::Error::unknown_field(
239 "is_infix_op not supported for unary functions",
240 ));
241 }
242 if output_type_expr.is_some() {
243 return Err(darling::Error::unknown_field(
244 "output_type_expr not supported for unary functions",
245 ));
246 }
247 if negate.is_some() {
248 return Err(darling::Error::unknown_field(
249 "negate not supported for unary functions",
250 ));
251 }
252 if propagates_nulls.is_some() {
253 return Err(darling::Error::unknown_field(
254 "propagates_nulls not supported for unary functions",
255 ));
256 }
257
258 let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
259 quote! {
260 fn preserves_uniqueness(&self) -> bool {
261 #preserves_uniqueness
262 }
263 }
264 });
265
266 let inverse_fn = inverse.as_ref().map(|inverse| {
267 quote! {
268 fn inverse(&self) -> Option<crate::UnaryFunc> {
269 #inverse
270 }
271 }
272 });
273
274 let is_monotone_fn = is_monotone.map(|is_monotone| {
275 quote! {
276 fn is_monotone(&self) -> bool {
277 #is_monotone
278 }
279 }
280 });
281
282 let name = sqlname
283 .as_ref()
284 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
285
286 let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
287 let introduces_nulls_fn = quote! {
288 fn introduces_nulls(&self) -> bool {
289 <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
290 }
291 };
292 let output_type = quote! { <#output_type> };
293 (output_type, Some(introduces_nulls_fn))
294 } else {
295 (quote! { Self::Output }, None)
296 };
297
298 if let Some(introduces_nulls) = introduces_nulls {
299 introduces_nulls_fn = Some(quote! {
300 fn introduces_nulls(&self) -> bool {
301 #introduces_nulls
302 }
303 });
304 }
305
306 let could_error_fn = could_error.map(|could_error| {
307 quote! {
308 fn could_error(&self) -> bool {
309 #could_error
310 }
311 }
312 });
313
314 let result = quote! {
315 #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
316 pub struct #struct_name;
317
318 impl<'a> crate::func::EagerUnaryFunc<'a> for #struct_name {
319 type Input = #input_ty;
320 type Output = #output_ty;
321
322 fn call(&self, a: Self::Input) -> Self::Output {
323 #fn_name(a)
324 }
325
326 fn output_type(&self, input_type: mz_repr::ColumnType) -> mz_repr::ColumnType {
327 use mz_repr::AsColumnType;
328 let output = #output_type::as_column_type();
329 let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
330 let nullable = output.nullable;
331 output.nullable(nullable || (propagates_nulls && input_type.nullable))
334 }
335
336 #could_error_fn
337 #introduces_nulls_fn
338 #inverse_fn
339 #is_monotone_fn
340 #preserves_uniqueness_fn
341 }
342
343 impl std::fmt::Display for #struct_name {
344 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
345 f.write_str(#name)
346 }
347 }
348
349 #func
350 };
351 Ok(result)
352}
353
354fn binary_func(
356 func: &syn::ItemFn,
357 modifiers: Modifiers,
358 arena: bool,
359) -> darling::Result<TokenStream> {
360 let fn_name = &func.sig.ident;
361 let struct_name = camel_case(&func.sig.ident);
362 let input1_ty = arg_type(func, 0)?;
363 let input2_ty = arg_type(func, 1)?;
364 let output_ty = output_type(func)?;
365
366 let Modifiers {
367 is_monotone,
368 sqlname,
369 preserves_uniqueness,
370 inverse,
371 is_infix_op,
372 output_type,
373 output_type_expr,
374 negate,
375 could_error,
376 propagates_nulls,
377 introduces_nulls,
378 } = modifiers;
379
380 if preserves_uniqueness.is_some() {
381 return Err(darling::Error::unknown_field(
382 "preserves_uniqueness not supported for binary functions",
383 ));
384 }
385 if inverse.is_some() {
386 return Err(darling::Error::unknown_field(
387 "inverse not supported for binary functions",
388 ));
389 }
390 if output_type.is_some() && output_type_expr.is_some() {
391 return Err(darling::Error::unknown_field(
392 "output_type and output_type_expr cannot be used together",
393 ));
394 }
395 if output_type_expr.is_some() && introduces_nulls.is_none() {
396 return Err(darling::Error::unknown_field(
397 "output_type_expr requires introduces_nulls",
398 ));
399 }
400
401 let negate_fn = negate.map(|negate| {
402 quote! {
403 fn negate(&self) -> Option<crate::BinaryFunc> {
404 #negate
405 }
406 }
407 });
408
409 let is_monotone_fn = is_monotone.map(|is_monotone| {
410 quote! {
411 fn is_monotone(&self) -> (bool, bool) {
412 #is_monotone
413 }
414 }
415 });
416
417 let name = sqlname
418 .as_ref()
419 .map_or_else(|| quote! { stringify!(#fn_name) }, |name| quote! { #name });
420
421 let (mut output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
422 let introduces_nulls_fn = quote! {
423 fn introduces_nulls(&self) -> bool {
424 <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
425 }
426 };
427 let output_type = quote! { <#output_type>::as_column_type() };
428 (output_type, Some(introduces_nulls_fn))
429 } else {
430 (quote! { Self::Output::as_column_type() }, None)
431 };
432
433 if let Some(output_type_expr) = output_type_expr {
434 output_type = quote! { #output_type_expr };
435 }
436
437 if let Some(introduces_nulls) = introduces_nulls {
438 introduces_nulls_fn = Some(quote! {
439 fn introduces_nulls(&self) -> bool {
440 #introduces_nulls
441 }
442 });
443 }
444
445 let arena = if arena {
446 quote! { , temp_storage }
447 } else {
448 quote! {}
449 };
450
451 let could_error_fn = could_error.map(|could_error| {
452 quote! {
453 fn could_error(&self) -> bool {
454 #could_error
455 }
456 }
457 });
458
459 let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
460 quote! {
461 fn is_infix_op(&self) -> bool {
462 #is_infix_op
463 }
464 }
465 });
466
467 let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
468 quote! {
469 fn propagates_nulls(&self) -> bool {
470 #propagates_nulls
471 }
472 }
473 });
474
475 let result = quote! {
476 #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
477 pub struct #struct_name;
478
479 impl<'a> crate::func::binary::EagerBinaryFunc<'a> for #struct_name {
480 type Input1 = #input1_ty;
481 type Input2 = #input2_ty;
482 type Output = #output_ty;
483
484 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a mz_repr::RowArena) -> Self::Output {
485 #fn_name(a, b #arena)
486 }
487
488 fn output_type(&self, input_type_a: mz_repr::ColumnType, input_type_b: mz_repr::ColumnType) -> mz_repr::ColumnType {
489 use mz_repr::AsColumnType;
490 let output = #output_type;
491 let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
492 let nullable = output.nullable;
493 output.nullable(nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)))
496 }
497
498 #could_error_fn
499 #introduces_nulls_fn
500 #is_infix_op_fn
501 #is_monotone_fn
502 #negate_fn
503 #propagates_nulls_fn
504 }
505
506 impl std::fmt::Display for #struct_name {
507 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
508 f.write_str(#name)
509 }
510 }
511
512 #func
513
514 };
515 Ok(result)
516}