mz_walkabout/
ir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Intermediate representation (IR) for codegen.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::iter;
14
15use anyhow::{Result, bail};
16use itertools::Itertools;
17use quote::ToTokens;
18
19/// The intermediate representation.
20pub struct Ir {
21    /// The items in the IR.
22    pub items: BTreeMap<String, Item>,
23    /// The generic parameters that appear throughout the IR.
24    ///
25    /// Walkabout assumes that generic parameters are named consistently
26    /// throughout the types in the IR. This field maps each generic parameter
27    /// to the union of all trait bounds required of that parameter.
28    pub generics: BTreeMap<String, BTreeSet<String>>,
29}
30
31/// An item in the IR.
32#[derive(Debug)]
33pub enum Item {
34    /// A struct item.
35    Struct(Struct),
36    /// An enum item.
37    Enum(Enum),
38    /// An abstract item, introduced via a generic parameter.
39    Abstract,
40}
41
42impl Item {
43    pub fn fields<'a>(&'a self) -> Box<dyn Iterator<Item = &'a Field> + 'a> {
44        match self {
45            Item::Struct(s) => Box::new(s.fields.iter()),
46            Item::Enum(e) => Box::new(e.variants.iter().flat_map(|v| &v.fields)),
47            Item::Abstract => Box::new(iter::empty()),
48        }
49    }
50
51    pub fn generics(&self) -> &[ItemGeneric] {
52        match self {
53            Item::Struct(s) => &s.generics,
54            Item::Enum(e) => &e.generics,
55            Item::Abstract => &[],
56        }
57    }
58}
59
60/// A struct in the IR.
61#[derive(Debug)]
62pub struct Struct {
63    /// The fields of the struct.
64    pub fields: Vec<Field>,
65    /// The generic parameters on the struct.
66    pub generics: Vec<ItemGeneric>,
67}
68
69/// An enum in the IRs.
70#[derive(Debug)]
71pub struct Enum {
72    /// The variants of the enum.
73    pub variants: Vec<Variant>,
74    /// The generic parameters on the enum.
75    pub generics: Vec<ItemGeneric>,
76}
77
78/// A variant of an [`Enum`].
79#[derive(Debug)]
80pub struct Variant {
81    /// The name of the variant.
82    pub name: String,
83    /// The fields of the variant.
84    pub fields: Vec<Field>,
85}
86
87/// A field of a [`Variant`] or [`Struct`].
88#[derive(Debug)]
89pub struct Field {
90    /// The optional name of the field.
91    ///
92    /// If omitted, the field is referred to by its index in its container.
93    pub name: Option<String>,
94    /// The type of the field.
95    pub ty: Type,
96}
97
/// A generic type parameter declared on an [`Item`].
#[derive(Debug)]
pub struct ItemGeneric {
    /// The parameter's name.
    pub name: String,
    /// The names of the traits that bound the parameter.
    pub bounds: Vec<String>,
}
106
/// The analyzed type of a [`Field`].
#[derive(Debug)]
pub enum Type {
    /// A primitive Rust type.
    ///
    /// Primitive types never need to be visited.
    Primitive,
    /// An abstract type, identified by its path.
    ///
    /// Abstract types are visited, but their default visit function does
    /// nothing.
    Abstract(String),
    /// An [`Option`] type.
    ///
    /// The wrapped value needs to be visited when the option is `Some`.
    Option(Box<Type>),
    /// A [`Vec`] type.
    ///
    /// Every element of the vector needs to be visited.
    Vec(Box<Type>),
    /// A [`Box`] type.
    ///
    /// The boxed value needs to be visited.
    Box(Box<Type>),
    /// A type local to the AST, identified by its name.
    ///
    /// The value needs to be visited by calling the appropriate `Visit` or
    /// `VisitMut` trait method on it.
    Local(String),
    /// A `BTreeMap` type.
    ///
    /// Each value will need to be visited.
    Map {
        /// The type of the map's keys.
        key: Box<Type>,
        /// The type of the map's values.
        value: Box<Type>,
    },
}
142
143/// Analyzes the provided items and produces an IR.
144///
145/// This is a very, very lightweight semantic analysis phase for Rust code. Our
146/// main goal is to determine the type of each field of a struct or enum
147/// variant, so we know how to visit it. See [`Type`] for details.
148pub(crate) fn analyze(syn_items: &[syn::DeriveInput]) -> Result<Ir> {
149    let mut items = BTreeMap::new();
150    for syn_item in syn_items {
151        let name = syn_item.ident.to_string();
152        let generics = analyze_generics(&syn_item.generics)?;
153        let item = match &syn_item.data {
154            syn::Data::Struct(s) => Item::Struct(Struct {
155                fields: analyze_fields(&s.fields)?,
156                generics,
157            }),
158            syn::Data::Enum(e) => {
159                let mut variants = vec![];
160                for v in &e.variants {
161                    variants.push(Variant {
162                        name: v.ident.to_string(),
163                        fields: analyze_fields(&v.fields)?,
164                    });
165                }
166                Item::Enum(Enum { variants, generics })
167            }
168            syn::Data::Union(_) => bail!("Unable to analyze union: {}", syn_item.ident),
169        };
170        for field in item.fields() {
171            let mut field_ty = &field.ty;
172            while let Type::Box(ty) | Type::Vec(ty) | Type::Option(ty) = field_ty {
173                field_ty = ty;
174            }
175            if let Type::Abstract(name) = field_ty {
176                items.insert(name.clone(), Item::Abstract);
177            }
178        }
179        items.insert(name, item);
180    }
181
182    let mut generics = BTreeMap::<_, BTreeSet<String>>::new();
183    for item in items.values() {
184        for ig in item.generics() {
185            generics
186                .entry(ig.name.clone())
187                .or_default()
188                .extend(ig.bounds.clone());
189        }
190    }
191
192    for item in items.values() {
193        validate_fields(&items, item.fields())?
194    }
195
196    Ok(Ir { items, generics })
197}
198
199fn validate_fields<'a, I>(items: &BTreeMap<String, Item>, fields: I) -> Result<()>
200where
201    I: IntoIterator<Item = &'a Field>,
202{
203    for f in fields {
204        match &f.ty {
205            Type::Local(s) if !items.contains_key(s) => {
206                bail!(
207                    "Unable to analyze non built-in type that is not defined in input: {}",
208                    s
209                );
210            }
211            _ => (),
212        }
213    }
214    Ok(())
215}
216
217fn analyze_fields(fields: &syn::Fields) -> Result<Vec<Field>> {
218    fields
219        .iter()
220        .map(|f| {
221            Ok(Field {
222                name: f.ident.as_ref().map(|id| id.to_string()),
223                ty: analyze_type(&f.ty)?,
224            })
225        })
226        .collect()
227}
228
229fn analyze_generics(generics: &syn::Generics) -> Result<Vec<ItemGeneric>> {
230    let mut out = vec![];
231    for g in generics.params.iter() {
232        match g {
233            syn::GenericParam::Type(syn::TypeParam { ident, bounds, .. }) => {
234                let name = ident.to_string();
235                let bounds = analyze_generic_bounds(bounds)?;
236                // Generic parameter names that end in '2' conflict with the
237                // folder's name generation.
238                if name.ends_with('2') {
239                    bail!(
240                        "Generic parameters whose name ends in '2' conflict with folder's naming scheme: {}",
241                        name
242                    );
243                }
244                out.push(ItemGeneric { name, bounds });
245            }
246            _ => {
247                bail!(
248                    "Unable to analyze non-type generic parameter: {}",
249                    g.to_token_stream()
250                )
251            }
252        }
253    }
254    Ok(out)
255}
256
257fn analyze_generic_bounds<'a, I>(bounds: I) -> Result<Vec<String>>
258where
259    I: IntoIterator<Item = &'a syn::TypeParamBound>,
260{
261    let mut out = vec![];
262    for b in bounds {
263        match b {
264            syn::TypeParamBound::Trait(t) if t.path.segments.len() != 1 => {
265                bail!(
266                    "Unable to analyze trait bound with more than one path segment: {}",
267                    b.to_token_stream()
268                )
269            }
270            syn::TypeParamBound::Trait(t) => out.push(t.path.segments[0].ident.to_string()),
271            _ => bail!("Unable to analyze non-trait bound: {}", b.to_token_stream()),
272        }
273    }
274    Ok(out)
275}
276
277fn analyze_type(ty: &syn::Type) -> Result<Type> {
278    match ty {
279        syn::Type::Path(syn::TypePath { qself: None, path }) => match path.segments.len() {
280            2 => {
281                let name = path.segments.iter().map(|s| s.ident.to_string()).join("::");
282                Ok(Type::Abstract(name))
283            }
284            1 => {
285                let segment = path.segments.last().unwrap();
286                let segment_name = segment.ident.to_string();
287
288                let container = |construct_ty: fn(Box<Type>) -> Type| match &segment.arguments {
289                    syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
290                        match args.args.last().unwrap() {
291                            syn::GenericArgument::Type(ty) => {
292                                let inner = Box::new(analyze_type(ty)?);
293                                Ok(construct_ty(inner))
294                            }
295                            _ => bail!(
296                                "Container type argument is not a basic (i.e., non-lifetime, non-constraint) type argument: {}",
297                                ty.into_token_stream()
298                            ),
299                        }
300                    }
301                    syn::PathArguments::AngleBracketed(_) => bail!(
302                        "Container type does not have exactly one type argument: {}",
303                        ty.into_token_stream()
304                    ),
305                    syn::PathArguments::Parenthesized(_) => bail!(
306                        "Container type has unexpected parenthesized type arguments: {}",
307                        ty.into_token_stream()
308                    ),
309                    syn::PathArguments::None => bail!(
310                        "Container type is missing type argument: {}",
311                        ty.into_token_stream()
312                    ),
313                };
314
315                match &*segment_name {
316                    "bool" | "usize" | "u8" | "u16" | "u32" | "u64" | "isize" | "i8" | "i16"
317                    | "i32" | "i64" | "f32" | "f64" | "char" | "String" | "PathBuf" => {
318                        match segment.arguments {
319                            syn::PathArguments::None => Ok(Type::Primitive),
320                            _ => bail!(
321                                "Primitive type had unexpected arguments: {}",
322                                ty.into_token_stream()
323                            ),
324                        }
325                    }
326                    "Vec" => container(Type::Vec),
327                    "Option" => container(Type::Option),
328                    "Box" => container(Type::Box),
329                    "BTreeMap" => match &segment.arguments {
330                        syn::PathArguments::None => bail!("Map type missing arguments"),
331                        syn::PathArguments::AngleBracketed(args) if args.args.len() == 2 => {
332                            let key = match &args.args[0] {
333                                syn::GenericArgument::Type(t) => t,
334                                _ => bail!("Invalid argument to map container, should be a Type"),
335                            };
336                            let value = match &args.args[1] {
337                                syn::GenericArgument::Type(t) => t,
338                                _ => bail!("Invalid argument to map container, should be a Type"),
339                            };
340                            Ok(Type::Map {
341                                key: Box::new(analyze_type(key)?),
342                                value: Box::new(analyze_type(value)?),
343                            })
344                        }
345                        &syn::PathArguments::AngleBracketed(_) => {
346                            bail!("wrong type of arguments for map container")
347                        }
348                        syn::PathArguments::Parenthesized(_) => {
349                            bail!("wrong type of arguments for map container")
350                        }
351                    },
352                    _ => Ok(Type::Local(segment_name)),
353                }
354            }
355            _ => {
356                bail!(
357                    "Unable to analyze type path with more than two components: '{}'",
358                    path.into_token_stream()
359                )
360            }
361        },
362        _ => bail!(
363            "Unable to analyze non-struct, non-enum type: {}",
364            ty.into_token_stream()
365        ),
366    }
367}