mz_lowertest_derive/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Macros needed by the `mz_lowertest` crate.
11//!
12//! TODO: eliminate macros in favor of using `walkabout`?
13
14use proc_macro::TokenStream;
15use quote::{ToTokens, quote};
16use syn::{Data, DeriveInput, Fields, parse};
17
/// Names of types defined outside of Materialize that may appear in test
/// objects. A plain path type whose last segment is one of these names is not
/// recursed into when collecting reflected type information.
const EXTERNAL_TYPES: &[&str] = &[
    "String",
    "FixedOffset",
    "Tz",
    "NaiveDateTime",
    "Regex",
];
/// Generic wrapper types (`Vec<A>`, `Box<A>`, `Option<A>`) whose type
/// arguments are inspected for further types to reflect. Any other
/// angle-bracketed type is skipped.
const SUPPORTED_ANGLE_TYPES: &[&str] = &["Vec", "Box", "Option"];
21
22/// Macro generating an implementation for the trait MzReflect
23#[proc_macro_derive(MzReflect, attributes(mzreflect))]
24pub fn mzreflect_derive(input: TokenStream) -> TokenStream {
25    // The intended trait implementation is
26    // ```
27    // impl MzReflect for #name {
28    //    /// Adds the information required to create an object of this type
29    //    /// to `enum_dict` if it is an enum and to `struct_dict` if it is a
30    //    /// struct.
31    //    fn add_to_reflected_type_info(
32    //        rti: &mut mz_lowertest::ReflectedTypeInfo
33    //    )
34    //    {
35    //       // if the object is an enum
36    //       if rti.enum_dict.contains_key(#name) { return; }
37    //       use std::collections::BTreeMap;
38    //       let mut result = BTreeMap::new();
39    //       // repeat line below for all variants
40    //       result.insert(variant_name, (<field_names>, <field_types>));
41    //       rti.enum_dist.insert(<enum_name>, result);
42    //
43    //       // if the object is a struct
44    //       if rti.struct_dict.contains_key(#name) { return ; }
45    //       rti.struct_dict.insert(#name, (<field_names>, <field_types>));
46    //
47    //       // for all object types, repeat line below for each field type
48    //       // that should be recursively added to the reflected type info
49    //       <field_type>::add_reflect_type_info(enum_dict, struct_dict);
50    //    }
51    // }
52    // ```
53    let ast: DeriveInput = parse(input).unwrap();
54
55    let object_name = &ast.ident;
56    let object_name_as_string = object_name.to_string();
57    let mut referenced_types = Vec::new();
58    let add_object_info = if let Data::Enum(enumdata) = &ast.data {
59        let variants = enumdata
60            .variants
61            .iter()
62            .map(|v| {
63                let variant_name = v.ident.to_string();
64                let (names, types_as_string, mut types_as_syn) = get_fields_names_types(&v.fields);
65                referenced_types.append(&mut types_as_syn);
66                quote! {
67                    result.insert(#variant_name, (vec![#(#names),*], vec![#(#types_as_string),*]));
68                }
69            })
70            .collect::<Vec<_>>();
71        quote! {
72            if rti.enum_dict.contains_key(#object_name_as_string) { return; }
73            use std::collections::BTreeMap;
74            let mut result = BTreeMap::new();
75            #(#variants)*
76            rti.enum_dict.insert(#object_name_as_string, result);
77        }
78    } else if let Data::Struct(structdata) = &ast.data {
79        let (names, types_as_string, mut types_as_syn) = get_fields_names_types(&structdata.fields);
80        referenced_types.append(&mut types_as_syn);
81        quote! {
82            if rti.struct_dict.contains_key(#object_name_as_string) { return; }
83            rti.struct_dict.insert(#object_name_as_string,
84                (vec![#(#names),*], vec![#(#types_as_string),*]));
85        }
86    } else {
87        unreachable!("Not a struct or enum")
88    };
89
90    let referenced_types = referenced_types
91        .into_iter()
92        .flat_map(extract_reflected_type)
93        .map(|typ| quote! { #typ::add_to_reflected_type_info(rti); })
94        .collect::<Vec<_>>();
95
96    let generated = quote! {
97      impl mz_lowertest::MzReflect for #object_name {
98        fn add_to_reflected_type_info(
99            rti: &mut mz_lowertest::ReflectedTypeInfo
100        )
101        {
102           #add_object_info
103           #(#referenced_types)*
104        }
105      }
106    };
107    generated.into()
108}
109
110/* #region Helper methods */
111
112/// Gets the names and the types of the fields of an enum variant or struct.
113///
114/// The result has three parts:
115/// 1. The names of the fields. If the fields are unnamed, this is empty.
116/// 2. The types of the fields as strings.
117/// 3. The types of the fields as [syn::Type]
118///
119/// Fields with the attribute `#[mzreflect(ignore)]` are not returned.
120fn get_fields_names_types(f: &syn::Fields) -> (Vec<String>, Vec<String>, Vec<&syn::Type>) {
121    match f {
122        Fields::Named(named_fields) => {
123            let (names, types): (Vec<_>, Vec<_>) = named_fields
124                .named
125                .iter()
126                .flat_map(get_field_name_type)
127                .unzip();
128            let (types_as_string, types_as_syn) = types.into_iter().unzip();
129            (names, types_as_string, types_as_syn)
130        }
131        Fields::Unnamed(unnamed_fields) => {
132            let (types_as_string, types_as_syn): (Vec<_>, Vec<_>) = unnamed_fields
133                .unnamed
134                .iter()
135                .flat_map(get_field_name_type)
136                .map(|(_, (type_as_string, type_as_syn))| (type_as_string, type_as_syn))
137                .unzip();
138            (Vec::new(), types_as_string, types_as_syn)
139        }
140        Fields::Unit => (Vec::new(), Vec::new(), Vec::new()),
141    }
142}
143
144/// Gets the name and the type of a field of an enum variant or struct.
145///
146/// The result has three parts:
147/// 1. The name of the field. If the field is unnamed, this is empty.
148/// 2. The type of the field as a string.
149/// 3. The type of the field as [syn::Type].
150///
151/// Returns None if the field has the attribute `#[mzreflect(ignore)]`.
152fn get_field_name_type(f: &syn::Field) -> Option<(String, (String, &syn::Type))> {
153    for attr in f.attrs.iter() {
154        if let Ok(syn::Meta::List(meta_list)) = attr.parse_meta() {
155            if meta_list.path.segments.last().unwrap().ident == "mzreflect" {
156                for nested_meta in meta_list.nested.iter() {
157                    if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested_meta {
158                        if path.segments.last().unwrap().ident == "ignore" {
159                            return None;
160                        }
161                    }
162                }
163            }
164        }
165    }
166    let name = if let Some(name) = f.ident.as_ref() {
167        name.to_string()
168    } else {
169        "".to_string()
170    };
171    Some((name, (get_type_as_string(&f.ty), &f.ty)))
172}
173
174/// Gets the type name from the [`syn::Type`] object
175fn get_type_as_string(t: &syn::Type) -> String {
176    // convert type back into a token stream and then into a string
177    let mut token_stream = proc_macro2::TokenStream::new();
178    t.to_tokens(&mut token_stream);
179    token_stream.to_string()
180}
181
182/// If `t` is a supported type, extracts from `t` types defined in a
183/// Materialize package.
184///
185/// Returns an empty vector if `t` is of an unsupported type.
186///
187/// Supported types are:
188/// A plain path type A -> extracts A
189/// `Box<A>`, `Vec<A>`, `Option<A>`, `[A]` -> extracts A
190/// Tuple (A, (B, C)) -> extracts A, B, C.
191/// Remove A, B, C from expected results if they are primitive types or listed
192/// in [EXTERNAL_TYPES].
193fn extract_reflected_type(t: &syn::Type) -> Vec<&syn::Type> {
194    match t {
195        syn::Type::Group(tg) => {
196            return extract_reflected_type(&tg.elem);
197        }
198        syn::Type::Path(tp) => {
199            let last_segment = tp.path.segments.last().unwrap();
200            let type_name = last_segment.ident.to_string();
201            match &last_segment.arguments {
202                syn::PathArguments::None => {
203                    if EXTERNAL_TYPES.contains(&&type_name[..])
204                        || type_name.starts_with(|c: char| c.is_lowercase())
205                    {
206                        // Ignore primitive types and types
207                        return Vec::new();
208                    } else {
209                        return vec![t];
210                    }
211                }
212                syn::PathArguments::AngleBracketed(args) => {
213                    if SUPPORTED_ANGLE_TYPES.contains(&&type_name[..]) {
214                        return args
215                            .args
216                            .iter()
217                            .flat_map(|arg| {
218                                if let syn::GenericArgument::Type(typ) = arg {
219                                    extract_reflected_type(typ)
220                                } else {
221                                    Vec::new()
222                                }
223                            })
224                            .collect::<Vec<_>>();
225                    }
226                }
227                _ => {}
228            }
229        }
230        syn::Type::Tuple(tt) => {
231            return tt
232                .elems
233                .iter()
234                .flat_map(extract_reflected_type)
235                .collect::<Vec<_>>();
236        }
237        syn::Type::Slice(ts) => {
238            return extract_reflected_type(&ts.elem);
239        }
240        _ => {}
241    }
242    Vec::new()
243}
244
245/* #endregion */