1use darling::FromMeta;
11use proc_macro2::{Ident, TokenStream};
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::{Expr, Lifetime};
15
/// Modifiers accepted as key-value pairs by the `#[sqlfunc(...)]` attribute.
///
/// Expression-valued modifiers are spliced verbatim into the body of the
/// corresponding generated trait method. Which modifiers are accepted depends on
/// the arity of the annotated function; unsupported combinations are rejected by
/// `unary_func` / `binary_func`.
#[derive(Debug, Default, darling::FromMeta)]
pub(crate) struct Modifiers {
    /// Body of the generated `is_monotone` method. Unary functions generate a
    /// method returning `bool`, binary functions one returning `(bool, bool)`.
    is_monotone: Option<Expr>,
    /// String written by the generated `Display` impl. When absent, the Rust
    /// function name (via `stringify!`) is used instead.
    sqlname: Option<String>,
    /// Body of the generated `preserves_uniqueness` method (unary functions only).
    preserves_uniqueness: Option<Expr>,
    /// Body of the generated `inverse` method returning `Option<crate::UnaryFunc>`
    /// (unary functions only).
    inverse: Option<Expr>,
    /// Body of the generated `negate` method returning `Option<crate::BinaryFunc>`
    /// (binary functions only).
    negate: Option<Expr>,
    /// Body of the generated `is_infix_op` method (binary functions only).
    is_infix_op: Option<Expr>,
    /// Type whose `as_column_type()` is used by the generated `output_type` method
    /// instead of `Self::Output`. When set, a default `introduces_nulls` method is
    /// also generated from `<T as ::mz_repr::DatumType<'_, ()>>::nullable()`.
    output_type: Option<syn::Path>,
    /// Body of the generated `could_error` method.
    could_error: Option<Expr>,
    /// Body of the generated `propagates_nulls` method (binary functions only).
    propagates_nulls: Option<Expr>,
    /// Body of the generated `introduces_nulls` method. Takes precedence over the
    /// method derived from `output_type`.
    introduces_nulls: Option<Expr>,
}
43
44pub fn sqlfunc(
50 attr: TokenStream,
51 item: TokenStream,
52 include_test: bool,
53) -> darling::Result<TokenStream> {
54 let attr_args = darling::ast::NestedMeta::parse_meta_list(attr.clone())?;
55 let modifiers = Modifiers::from_list(&attr_args)?;
56 let func = syn::parse2::<syn::ItemFn>(item.clone())?;
57
58 let tokens = match determine_parameters_arena(&func) {
59 (1, false) => unary_func(&func, modifiers),
60 (1, true) => Err(darling::Error::custom(
61 "Unary functions do not yet support RowArena.",
62 )),
63 (2, arena) => binary_func(&func, modifiers, arena),
64 (other, _) => Err(darling::Error::custom(format!(
65 "Unsupported function: {} parameters",
66 other
67 ))),
68 }?;
69
70 let test = include_test.then(|| generate_test(attr, item, &func.sig.ident));
71
72 Ok(quote! {
73 #tokens
74 #test
75 })
76}
77
/// Generates a snapshot test for the `#[sqlfunc]` expansion of `item`.
///
/// The emitted `#[cfg(test)]` function re-runs the expansion on the stringified
/// `attr` and `item` via `mz_expr_derive_impl::test_sqlfunc_str` and passes the
/// result to `insta::assert_snapshot!` under the function's name.
#[cfg(any(feature = "test", test))]
fn generate_test(attr: TokenStream, item: TokenStream, name: &Ident) -> TokenStream {
    let attr = attr.to_string();
    let item = item.to_string();
    // `format_ident!` strips a raw-identifier prefix (`r#`) and reuses `name`'s
    // span; building the name with `Ident::new(&format!(..))` would panic on
    // `test_r#...`.
    let test_name = quote::format_ident!("test_{}", name);
    let fn_name = name.to_string();

    quote! {
        #[cfg(test)]
        #[cfg_attr(miri, ignore)] #[mz_ore::test]
        fn #test_name() {
            let (output, input) = mz_expr_derive_impl::test_sqlfunc_str(#attr, #item);
            insta::assert_snapshot!(#fn_name, output, &input);
        }
    }
}
95
96#[cfg(not(any(feature = "test", test)))]
97fn generate_test(_attr: TokenStream, _item: TokenStream, _name: &Ident) -> TokenStream {
98 quote! {}
99}
100
101fn determine_parameters_arena(func: &syn::ItemFn) -> (usize, bool) {
104 let last_is_arena = func.sig.inputs.last().map_or(false, |last| {
105 if let syn::FnArg::Typed(pat) = last {
106 if let syn::Type::Reference(reference) = &*pat.ty {
107 if let syn::Type::Path(path) = &*reference.elem {
108 return path.path.is_ident("RowArena");
109 }
110 }
111 }
112 false
113 });
114 let parameters = func.sig.inputs.len();
115 if last_is_arena {
116 (parameters - 1, true)
117 } else {
118 (parameters, false)
119 }
120}
121
122fn camel_case(ident: &Ident) -> Ident {
124 let mut result = String::new();
125 let mut capitalize_next = true;
126 for c in ident.to_string().chars() {
127 if c == '_' {
128 capitalize_next = true;
129 } else if capitalize_next {
130 result.push(c.to_ascii_uppercase());
131 capitalize_next = false;
132 } else {
133 result.push(c);
134 }
135 }
136 Ident::new(&result, ident.span())
137}
138
139fn arg_type(arg: &syn::ItemFn, nth: usize) -> Result<syn::Type, syn::Error> {
146 match &arg.sig.inputs[nth] {
147 syn::FnArg::Typed(pat) => {
148 if let syn::Type::Reference(r) = &*pat.ty {
150 if r.lifetime.is_none() {
151 let ty = syn::Type::Reference(syn::TypeReference {
152 lifetime: Some(Lifetime::new("'a", r.span())),
153 ..r.clone()
154 });
155 return Ok(ty);
156 }
157 }
158 Ok((*pat.ty).clone())
159 }
160 _ => Err(syn::Error::new(
161 arg.sig.inputs[nth].span(),
162 "Unsupported argument type",
163 )),
164 }
165}
166
167fn output_type(arg: &syn::ItemFn) -> Result<&syn::Type, syn::Error> {
170 match &arg.sig.output {
171 syn::ReturnType::Type(_, ty) => Ok(&*ty),
172 syn::ReturnType::Default => Err(syn::Error::new(
173 arg.sig.output.span(),
174 "Function needs to return a value",
175 )),
176 }
177}
178
179fn unary_func(func: &syn::ItemFn, modifiers: Modifiers) -> darling::Result<TokenStream> {
181 let fn_name = &func.sig.ident;
182 let struct_name = camel_case(&func.sig.ident);
183 let input_ty = arg_type(func, 0)?;
184 let output_ty = output_type(func)?;
185 let Modifiers {
186 is_monotone,
187 sqlname,
188 preserves_uniqueness,
189 inverse,
190 is_infix_op,
191 output_type,
192 negate,
193 could_error,
194 propagates_nulls,
195 introduces_nulls,
196 } = modifiers;
197
198 if is_infix_op.is_some() {
199 return Err(darling::Error::unknown_field(
200 "is_infix_op not supported for unary functions",
201 ));
202 }
203 if negate.is_some() {
204 return Err(darling::Error::unknown_field(
205 "negate not supported for unary functions",
206 ));
207 }
208 if propagates_nulls.is_some() {
209 return Err(darling::Error::unknown_field(
210 "propagates_nulls not supported for unary functions",
211 ));
212 }
213
214 let preserves_uniqueness_fn = preserves_uniqueness.map(|preserves_uniqueness| {
215 quote! {
216 fn preserves_uniqueness(&self) -> bool {
217 #preserves_uniqueness
218 }
219 }
220 });
221
222 let inverse_fn = inverse.as_ref().map(|inverse| {
223 quote! {
224 fn inverse(&self) -> Option<crate::UnaryFunc> {
225 #inverse
226 }
227 }
228 });
229
230 let is_monotone_fn = is_monotone.map(|is_monotone| {
231 quote! {
232 fn is_monotone(&self) -> bool {
233 #is_monotone
234 }
235 }
236 });
237
238 let name = if let Some(sqlname) = sqlname {
239 quote! {
240 #sqlname
241 }
242 } else {
243 quote! {
244 stringify!(#fn_name)
245 }
246 };
247
248 let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
249 let introduces_nulls_fn = quote! {
250 fn introduces_nulls(&self) -> bool {
251 <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
252 }
253 };
254 let output_type = quote! { <#output_type> };
255 (output_type, Some(introduces_nulls_fn))
256 } else {
257 (quote! { Self::Output }, None)
258 };
259
260 if let Some(introduces_nulls) = introduces_nulls {
261 introduces_nulls_fn = Some(quote! {
262 fn introduces_nulls(&self) -> bool {
263 #introduces_nulls
264 }
265 });
266 }
267
268 let could_error_fn = could_error.map(|could_error| {
269 quote! {
270 fn could_error(&self) -> bool {
271 #could_error
272 }
273 }
274 });
275
276 let result = quote! {
277 #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
278 pub struct #struct_name;
279
280 impl<'a> crate::func::EagerUnaryFunc<'a> for #struct_name {
281 type Input = #input_ty;
282 type Output = #output_ty;
283
284 fn call(&self, a: Self::Input) -> Self::Output {
285 #fn_name(a)
286 }
287
288 fn output_type(&self, input_type: mz_repr::ColumnType) -> mz_repr::ColumnType {
289 use mz_repr::AsColumnType;
290 let output = #output_type::as_column_type();
291 let propagates_nulls = crate::func::EagerUnaryFunc::propagates_nulls(self);
292 let nullable = output.nullable;
293 output.nullable(nullable || (propagates_nulls && input_type.nullable))
296 }
297
298 #could_error_fn
299 #introduces_nulls_fn
300 #inverse_fn
301 #is_monotone_fn
302 #preserves_uniqueness_fn
303 }
304
305 impl std::fmt::Display for #struct_name {
306 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
307 f.write_str(#name)
308 }
309 }
310
311 #func
312 };
313 Ok(result)
314}
315
316fn binary_func(
318 func: &syn::ItemFn,
319 modifiers: Modifiers,
320 arena: bool,
321) -> darling::Result<TokenStream> {
322 let fn_name = &func.sig.ident;
323 let struct_name = camel_case(&func.sig.ident);
324 let input1_ty = arg_type(func, 0)?;
325 let input2_ty = arg_type(func, 1)?;
326 let output_ty = output_type(func)?;
327
328 let Modifiers {
329 is_monotone,
330 sqlname,
331 preserves_uniqueness,
332 inverse,
333 is_infix_op,
334 output_type,
335 negate,
336 could_error,
337 propagates_nulls,
338 introduces_nulls,
339 } = modifiers;
340
341 if preserves_uniqueness.is_some() {
342 return Err(darling::Error::unknown_field(
343 "preserves_uniqueness not supported for binary functions",
344 ));
345 }
346 if inverse.is_some() {
347 return Err(darling::Error::unknown_field(
348 "inverse not supported for binary functions",
349 ));
350 }
351
352 let negate_fn = negate.map(|negate| {
353 quote! {
354 fn negate(&self) -> Option<crate::BinaryFunc> {
355 #negate
356 }
357 }
358 });
359
360 let is_monotone_fn = is_monotone.map(|is_monotone| {
361 quote! {
362 fn is_monotone(&self) -> (bool, bool) {
363 #is_monotone
364 }
365 }
366 });
367
368 let name = if let Some(sqlname) = sqlname {
369 quote! {
370 #sqlname
371 }
372 } else {
373 quote! {
374 stringify!(#fn_name)
375 }
376 };
377
378 let (output_type, mut introduces_nulls_fn) = if let Some(output_type) = output_type {
379 let introduces_nulls_fn = quote! {
380 fn introduces_nulls(&self) -> bool {
381 <#output_type as ::mz_repr::DatumType<'_, ()>>::nullable()
382 }
383 };
384 let output_type = quote! { <#output_type> };
385 (output_type, Some(introduces_nulls_fn))
386 } else {
387 (quote! { Self::Output }, None)
388 };
389
390 if let Some(introduces_nulls) = introduces_nulls {
391 introduces_nulls_fn = Some(quote! {
392 fn introduces_nulls(&self) -> bool {
393 #introduces_nulls
394 }
395 });
396 }
397
398 let arena = if arena {
399 quote! { , temp_storage }
400 } else {
401 quote! {}
402 };
403
404 let could_error_fn = could_error.map(|could_error| {
405 quote! {
406 fn could_error(&self) -> bool {
407 #could_error
408 }
409 }
410 });
411
412 let is_infix_op_fn = is_infix_op.map(|is_infix_op| {
413 quote! {
414 fn is_infix_op(&self) -> bool {
415 #is_infix_op
416 }
417 }
418 });
419
420 let propagates_nulls_fn = propagates_nulls.map(|propagates_nulls| {
421 quote! {
422 fn propagates_nulls(&self) -> bool {
423 #propagates_nulls
424 }
425 }
426 });
427
428 let result = quote! {
429 #[derive(proptest_derive::Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Hash, mz_lowertest::MzReflect)]
430 pub struct #struct_name;
431
432 impl<'a> crate::func::binary::EagerBinaryFunc<'a> for #struct_name {
433 type Input1 = #input1_ty;
434 type Input2 = #input2_ty;
435 type Output = #output_ty;
436
437 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a mz_repr::RowArena) -> Self::Output {
438 #fn_name(a, b #arena)
439 }
440
441 fn output_type(&self, input_type_a: mz_repr::ColumnType, input_type_b: mz_repr::ColumnType) -> mz_repr::ColumnType {
442 use mz_repr::AsColumnType;
443 let output = #output_type::as_column_type();
444 let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(self);
445 let nullable = output.nullable;
446 output.nullable(nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)))
449 }
450
451 #could_error_fn
452 #introduces_nulls_fn
453 #is_infix_op_fn
454 #is_monotone_fn
455 #negate_fn
456 #propagates_nulls_fn
457 }
458
459 impl std::fmt::Display for #struct_name {
460 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
461 f.write_str(#name)
462 }
463 }
464
465 #func
466
467 };
468 Ok(result)
469}