mz_lowertest_derive/lib.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Macros needed by the `mz_lowertest` crate.
11//!
12//! TODO: eliminate macros in favor of using `walkabout`?
13
14use proc_macro::TokenStream;
15use quote::{ToTokens, quote};
16use syn::{Data, DeriveInput, Fields, parse};
17
/// Types defined outside of Materialize used to build test objects.
///
/// `extract_reflected_type` skips plain path types whose last segment is one
/// of these names, so the derive emits no recursive call for them.
const EXTERNAL_TYPES: &[&str] = &["String", "FixedOffset", "Tz", "NaiveDateTime", "Regex"];
/// Generic wrapper types whose type arguments `extract_reflected_type`
/// descends into. Any other type with angle-bracketed arguments yields no
/// referenced types.
const SUPPORTED_ANGLE_TYPES: &[&str] = &["Vec", "Box", "Option"];
21
22/// Macro generating an implementation for the trait MzReflect
23#[proc_macro_derive(MzReflect, attributes(mzreflect))]
24pub fn mzreflect_derive(input: TokenStream) -> TokenStream {
25 // The intended trait implementation is
26 // ```
27 // impl MzReflect for #name {
28 // /// Adds the information required to create an object of this type
29 // /// to `enum_dict` if it is an enum and to `struct_dict` if it is a
30 // /// struct.
31 // fn add_to_reflected_type_info(
32 // rti: &mut mz_lowertest::ReflectedTypeInfo
33 // )
34 // {
35 // // if the object is an enum
36 // if rti.enum_dict.contains_key(#name) { return; }
37 // use std::collections::BTreeMap;
38 // let mut result = BTreeMap::new();
39 // // repeat line below for all variants
40 // result.insert(variant_name, (<field_names>, <field_types>));
41 // rti.enum_dist.insert(<enum_name>, result);
42 //
43 // // if the object is a struct
44 // if rti.struct_dict.contains_key(#name) { return ; }
45 // rti.struct_dict.insert(#name, (<field_names>, <field_types>));
46 //
47 // // for all object types, repeat line below for each field type
48 // // that should be recursively added to the reflected type info
49 // <field_type>::add_reflect_type_info(enum_dict, struct_dict);
50 // }
51 // }
52 // ```
53 let ast: DeriveInput = parse(input).unwrap();
54
55 let object_name = &ast.ident;
56 let object_name_as_string = object_name.to_string();
57 let mut referenced_types = Vec::new();
58 let add_object_info = if let Data::Enum(enumdata) = &ast.data {
59 let variants = enumdata
60 .variants
61 .iter()
62 .map(|v| {
63 let variant_name = v.ident.to_string();
64 let (names, types_as_string, mut types_as_syn) = get_fields_names_types(&v.fields);
65 referenced_types.append(&mut types_as_syn);
66 quote! {
67 result.insert(#variant_name, (vec![#(#names),*], vec![#(#types_as_string),*]));
68 }
69 })
70 .collect::<Vec<_>>();
71 quote! {
72 if rti.enum_dict.contains_key(#object_name_as_string) { return; }
73 use std::collections::BTreeMap;
74 let mut result = BTreeMap::new();
75 #(#variants)*
76 rti.enum_dict.insert(#object_name_as_string, result);
77 }
78 } else if let Data::Struct(structdata) = &ast.data {
79 let (names, types_as_string, mut types_as_syn) = get_fields_names_types(&structdata.fields);
80 referenced_types.append(&mut types_as_syn);
81 quote! {
82 if rti.struct_dict.contains_key(#object_name_as_string) { return; }
83 rti.struct_dict.insert(#object_name_as_string,
84 (vec![#(#names),*], vec![#(#types_as_string),*]));
85 }
86 } else {
87 unreachable!("Not a struct or enum")
88 };
89
90 let referenced_types = referenced_types
91 .into_iter()
92 .flat_map(extract_reflected_type)
93 .map(|typ| quote! { #typ::add_to_reflected_type_info(rti); })
94 .collect::<Vec<_>>();
95
96 let generated = quote! {
97 impl mz_lowertest::MzReflect for #object_name {
98 fn add_to_reflected_type_info(
99 rti: &mut mz_lowertest::ReflectedTypeInfo
100 )
101 {
102 #add_object_info
103 #(#referenced_types)*
104 }
105 }
106 };
107 generated.into()
108}
109
110/* #region Helper methods */
111
112/// Gets the names and the types of the fields of an enum variant or struct.
113///
114/// The result has three parts:
115/// 1. The names of the fields. If the fields are unnamed, this is empty.
116/// 2. The types of the fields as strings.
117/// 3. The types of the fields as [syn::Type]
118///
119/// Fields with the attribute `#[mzreflect(ignore)]` are not returned.
120fn get_fields_names_types(f: &syn::Fields) -> (Vec<String>, Vec<String>, Vec<&syn::Type>) {
121 match f {
122 Fields::Named(named_fields) => {
123 let (names, types): (Vec<_>, Vec<_>) = named_fields
124 .named
125 .iter()
126 .flat_map(get_field_name_type)
127 .unzip();
128 let (types_as_string, types_as_syn) = types.into_iter().unzip();
129 (names, types_as_string, types_as_syn)
130 }
131 Fields::Unnamed(unnamed_fields) => {
132 let (types_as_string, types_as_syn): (Vec<_>, Vec<_>) = unnamed_fields
133 .unnamed
134 .iter()
135 .flat_map(get_field_name_type)
136 .map(|(_, (type_as_string, type_as_syn))| (type_as_string, type_as_syn))
137 .unzip();
138 (Vec::new(), types_as_string, types_as_syn)
139 }
140 Fields::Unit => (Vec::new(), Vec::new(), Vec::new()),
141 }
142}
143
144/// Gets the name and the type of a field of an enum variant or struct.
145///
146/// The result has three parts:
147/// 1. The name of the field. If the field is unnamed, this is empty.
148/// 2. The type of the field as a string.
149/// 3. The type of the field as [syn::Type].
150///
151/// Returns None if the field has the attribute `#[mzreflect(ignore)]`.
152fn get_field_name_type(f: &syn::Field) -> Option<(String, (String, &syn::Type))> {
153 for attr in f.attrs.iter() {
154 if let Ok(syn::Meta::List(meta_list)) = attr.parse_meta() {
155 if meta_list.path.segments.last().unwrap().ident == "mzreflect" {
156 for nested_meta in meta_list.nested.iter() {
157 if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested_meta {
158 if path.segments.last().unwrap().ident == "ignore" {
159 return None;
160 }
161 }
162 }
163 }
164 }
165 }
166 let name = if let Some(name) = f.ident.as_ref() {
167 name.to_string()
168 } else {
169 "".to_string()
170 };
171 Some((name, (get_type_as_string(&f.ty), &f.ty)))
172}
173
174/// Gets the type name from the [`syn::Type`] object
175fn get_type_as_string(t: &syn::Type) -> String {
176 // convert type back into a token stream and then into a string
177 let mut token_stream = proc_macro2::TokenStream::new();
178 t.to_tokens(&mut token_stream);
179 token_stream.to_string()
180}
181
182/// If `t` is a supported type, extracts from `t` types defined in a
183/// Materialize package.
184///
185/// Returns an empty vector if `t` is of an unsupported type.
186///
187/// Supported types are:
188/// A plain path type A -> extracts A
189/// `Box<A>`, `Vec<A>`, `Option<A>`, `[A]` -> extracts A
190/// Tuple (A, (B, C)) -> extracts A, B, C.
191/// Remove A, B, C from expected results if they are primitive types or listed
192/// in [EXTERNAL_TYPES].
193fn extract_reflected_type(t: &syn::Type) -> Vec<&syn::Type> {
194 match t {
195 syn::Type::Group(tg) => {
196 return extract_reflected_type(&tg.elem);
197 }
198 syn::Type::Path(tp) => {
199 let last_segment = tp.path.segments.last().unwrap();
200 let type_name = last_segment.ident.to_string();
201 match &last_segment.arguments {
202 syn::PathArguments::None => {
203 if EXTERNAL_TYPES.contains(&&type_name[..])
204 || type_name.starts_with(|c: char| c.is_lowercase())
205 {
206 // Ignore primitive types and types
207 return Vec::new();
208 } else {
209 return vec![t];
210 }
211 }
212 syn::PathArguments::AngleBracketed(args) => {
213 if SUPPORTED_ANGLE_TYPES.contains(&&type_name[..]) {
214 return args
215 .args
216 .iter()
217 .flat_map(|arg| {
218 if let syn::GenericArgument::Type(typ) = arg {
219 extract_reflected_type(typ)
220 } else {
221 Vec::new()
222 }
223 })
224 .collect::<Vec<_>>();
225 }
226 }
227 _ => {}
228 }
229 }
230 syn::Type::Tuple(tt) => {
231 return tt
232 .elems
233 .iter()
234 .flat_map(extract_reflected_type)
235 .collect::<Vec<_>>();
236 }
237 syn::Type::Slice(ts) => {
238 return extract_reflected_type(&ts.elem);
239 }
240 _ => {}
241 }
242 Vec::new()
243}
244
245/* #endregion */