mz_walkabout/
generated.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Code generation.
11//!
12//! This module processes the IR to generate the `fold`, `visit`, and
13//! `visit_mut` modules.
14
15use std::collections::{BTreeMap, BTreeSet};
16
17use itertools::Itertools;
18use mz_ore_build::codegen::CodegenBuf;
19
20use crate::ir::{Ir, Item, Type};
21
22/// Generates a fold transformer for a mutable AST.
23///
24/// Returns a string of Rust code that should be compiled alongside the module
25/// from which it was generated.
26pub fn gen_fold(ir: &Ir) -> String {
27    gen_fold_root(ir)
28}
29
30/// Generates a visitor for an immutable AST.
31///
32/// Returns a string of Rust code that should be compiled alongside the module
33/// from which it was generated.
34pub fn gen_visit(ir: &Ir) -> String {
35    gen_visit_root(&VisitConfig { mutable: false }, ir)
36}
37
38/// Generates a visitor for a mutable AST.
39///
40/// Returns a string of Rust code that should be compiled alongside the module
41/// from which it was generated.
42pub fn gen_visit_mut(ir: &Ir) -> String {
43    gen_visit_root(&VisitConfig { mutable: true }, ir)
44}
45
/// Generates the complete source of the fold module for `ir`.
///
/// Every generic parameter `T` declared in `ir.generics` is paired with a
/// second parameter `T2` carrying the same bounds. Folding maps a node
/// parameterized by the former to the same node parameterized by the latter.
///
/// The emitted code contains, in order:
/// * a `Fold` trait with one `fold_*` method per IR item. Abstract items get a
///   required method; structs and enums get a default method that delegates to
///   the free function of the same name;
/// * a `FoldNode` trait with an associated `Folded` type and a `fold` method;
/// * for every struct and enum, a `FoldNode` impl followed by a free `fold_*`
///   function that rebuilds the node from its folded fields.
pub fn gen_fold_root(ir: &Ir) -> String {
    // Duplicate every generic parameter as `<name>2` with identical bounds so
    // the generated signatures can name both the input and the output types.
    let mut generics = BTreeMap::new();
    for (name, bounds) in &ir.generics {
        generics.insert(name.clone(), bounds.clone());
        generics.insert(format!("{name}2"), bounds.clone());
    }
    let trait_generics = trait_generics(&generics);
    let trait_generics_and_bounds = trait_generics_and_bounds(&generics);

    let mut buf = CodegenBuf::new();

    buf.write_block(
        format!("pub trait Fold<{trait_generics_and_bounds}>"),
        |buf| {
            for (name, item) in &ir.items {
                match item {
                    Item::Abstract => {
                        // The intent is to replace `T::FooBar` with `T2::FooBar`. This
                        // is a bit gross, but it seems reliable enough, and is so far
                        // simpler than trying to use a structured type for `name`.
                        let name2 = name.replacen("::", "2::", 1);
                        let fn_name = fold_fn_name(name);
                        buf.writeln(format!("fn {fn_name}(&mut self, node: {name}) -> {name2};"))
                    }
                    Item::Struct(_) | Item::Enum(_) => {
                        // Default method: forward to the free `fold_*` function
                        // emitted further below for this item.
                        let generics = item_generics(item, "");
                        let generics2 = item_generics(item, "2");
                        let fn_name = fold_fn_name(name);
                        buf.write_block(
                            format!("fn {fn_name}(&mut self, node: {name}{generics}) -> {name}{generics2}"),
                            |buf| buf.writeln(format!("{fn_name}(self, node)")),
                        );
                    }
                }
            }
        },
    );

    buf.write_block(
        format!("pub trait FoldNode<{trait_generics_and_bounds}>"),
        |buf| {
            buf.writeln("type Folded;");
            buf.writeln(format!(
                "fn fold<F: Fold<{trait_generics}>>(self, folder: &mut F) -> Self::Folded;"
            ));
        },
    );

    for (name, item) in &ir.items {
        // Abstract items only get a required `Fold` method (emitted above);
        // they have no `FoldNode` impl and no free function.
        if let Item::Abstract = item {
            continue;
        }
        let generics = item_generics(item, "");
        let generics2 = item_generics(item, "2");
        let fn_name = fold_fn_name(name);
        buf.write_block(
            format!(
                "impl<{trait_generics_and_bounds}> FoldNode<{trait_generics}> for {name}{generics}"
            ),
            |buf| {
                buf.writeln(format!("type Folded = {name}{generics2};"));
                buf.write_block(
                    format!(
                        "fn fold<F: Fold<{trait_generics}>>(self, folder: &mut F) -> Self::Folded"
                    ),
                    |buf| buf.writeln(format!("folder.{fn_name}(self)")),
                );
            },
        );

        // The generated function always takes `folder: &mut F`, even when its
        // body never uses it mutably (e.g. an item whose fields are all
        // primitive), so the lint is silenced in the generated code.
        buf.writeln("#[allow(clippy::needless_pass_by_ref_mut)]");
        buf.writeln(format!(
            "pub fn {fn_name}<F, {trait_generics_and_bounds}>(folder: &mut F, node: {name}{generics}) -> {name}{generics2}"
        ));
        buf.writeln("where");
        buf.writeln(format!("    F: Fold<{trait_generics}> + ?Sized,"));
        buf.write_block("", |buf| match item {
            Item::Struct(s) => {
                // Rebuild the struct with a braced literal. Tuple-struct fields
                // have no name, so their positional index is used
                // (`Name { 0: ... }` is valid Rust for tuple structs).
                buf.write_block(name, |buf| {
                    for (i, f) in s.fields.iter().enumerate() {
                        let field_name = match &f.name {
                            Some(name) => name.clone(),
                            None => i.to_string(),
                        };
                        let binding = format!("node.{field_name}");
                        buf.start_line();
                        buf.write(format!("{field_name}: "));
                        gen_fold_element(buf, &binding, &f.ty);
                        buf.write(",");
                        buf.end_line();
                    }
                });
            }
            Item::Enum(e) => {
                // One match arm per variant: destructure the fields into
                // `binding{i}` locals, then rebuild the same variant from the
                // folded bindings.
                buf.write_block("match node", |buf| {
                    for v in &e.variants {
                        let vname = &v.name;
                        buf.write_block(format!("{name}::{vname}"), |buf| {
                            // This `name` is the field name and is scoped to the
                            // loop; `{name}` after the loop is the item name again.
                            for (i, f) in v.fields.iter().enumerate() {
                                let name = f.name.clone().unwrap_or_else(|| i.to_string());
                                buf.writeln(format!("{name}: binding{i},"));
                            }
                            // NOTE(review): presumably closes the pattern braces and
                            // reopens a block after `=>` for the arm body; exact
                            // behavior is defined by `CodegenBuf`.
                            buf.restart_block("=>");
                            buf.write_block(format!("{name}::{vname}"), |buf| {
                                for (i, f) in v.fields.iter().enumerate() {
                                    let field_name = match &f.name {
                                        Some(name) => name.clone(),
                                        None => i.to_string(),
                                    };
                                    let binding = format!("binding{i}");
                                    buf.start_line();
                                    buf.write(format!("{field_name}: "));
                                    gen_fold_element(buf, &binding, &f.ty);
                                    buf.write(",");
                                    buf.end_line();
                                }
                            });
                        });
                    }
                });
            }
            // Unreachable: abstract items were skipped by the `continue` above.
            // The arm only keeps the match exhaustive.
            Item::Abstract => (),
        });
    }

    buf.into_string()
}
173
174fn gen_fold_element(buf: &mut CodegenBuf, binding: &str, ty: &Type) {
175    match ty {
176        Type::Primitive => buf.write(binding),
177        Type::Abstract(ty) => {
178            let fn_name = fold_fn_name(ty);
179            buf.write(format!("folder.{fn_name}({binding})"));
180        }
181        Type::Option(ty) => {
182            buf.write(format!("{binding}.map(|v| "));
183            gen_fold_element(buf, "v", ty);
184            buf.write(")")
185        }
186        Type::Vec(ty) => {
187            buf.write(format!("{binding}.into_iter().map(|v| "));
188            gen_fold_element(buf, "v", ty);
189            buf.write(").collect()");
190        }
191        Type::Box(ty) => {
192            buf.write("Box::new(");
193            gen_fold_element(buf, &format!("*{binding}"), ty);
194            buf.write(")");
195        }
196        Type::Local(s) => {
197            let fn_name = fold_fn_name(s);
198            buf.write(format!("folder.{fn_name}({binding})"));
199        }
200        Type::Map { key, value } => {
201            buf.write(format!(
202                "{{ std::collections::BTreeMap::from_iter({binding}.into_iter().map(|(k, v)| {{("
203            ));
204            gen_fold_element(buf, "k", key);
205            buf.write(".to_owned(), ");
206            gen_fold_element(buf, "v", value);
207            buf.write(".to_owned()) }) )}")
208        }
209    }
210}
211
/// Selects which flavor of visitor code `gen_visit_root` emits.
struct VisitConfig {
    /// `true` emits `VisitMut` / `visit_*_mut` over `&'ast mut` nodes;
    /// `false` emits `Visit` / `visit_*` over `&'ast` nodes.
    mutable: bool,
}
215
/// Generates the complete source of a visitor module for `ir`.
///
/// `c.mutable` selects the flavor: `Visit` / `visit_*` over `&'ast` nodes, or
/// `VisitMut` / `visit_*_mut` over `&'ast mut` nodes. Unlike the fold generator,
/// generic parameters are used exactly as declared in `ir.generics` (no `T2`
/// copies).
///
/// The emitted code contains, in order:
/// * a `Visit`/`VisitMut` trait with a default method per IR item (abstract
///   items included) that delegates to the free function of the same name;
/// * a `VisitNode`/`VisitMutNode` trait with a single `visit`/`visit_mut`
///   method;
/// * for every item, a `...Node` impl (structs and enums only), followed by a
///   free visit function. That function visits each field of a struct or each
///   variant's fields of an enum. For abstract items its body is empty.
fn gen_visit_root(c: &VisitConfig, ir: &Ir) -> String {
    let trait_name = if c.mutable { "VisitMut" } else { "Visit" };
    let fn_name_base = if c.mutable { "visit_mut" } else { "visit" };
    // Spliced between `&'ast` / `&` and the type to form `&'ast mut T` etc.
    let muta = if c.mutable { "mut " } else { "" };
    let trait_generics = trait_generics(&ir.generics);
    let trait_generics_and_bounds = trait_generics_and_bounds(&ir.generics);

    let mut buf = CodegenBuf::new();

    buf.write_block(
        format!("pub trait {trait_name}<'ast, {trait_generics_and_bounds}>"),
        |buf| {
            for (name, item) in &ir.items {
                let generics = item_generics(item, "");
                let fn_name = visit_fn_name(c, name);
                buf.write_block(
                    format!("fn {fn_name}(&mut self, node: &'ast {muta}{name}{generics})"),
                    |buf| buf.writeln(format!("{fn_name}(self, node)")),
                );
            }
        },
    );

    buf.write_block(format!(
        "pub trait {trait_name}Node<'ast, {trait_generics_and_bounds}>"
    ), |buf| buf.writeln(format!(
        "fn {fn_name_base}<V: {trait_name}<'ast, {trait_generics}>>(&'ast {muta}self, visitor: &mut V);"
    )));

    for (name, item) in &ir.items {
        let generics = item_generics(item, "");
        let fn_name = visit_fn_name(c, name);
        // Abstract items name associated types (e.g. `T::FooBar`), so no
        // `...Node` impl is emitted for them. They still get a free function.
        if !matches!(item, Item::Abstract) {
            buf.write_block(format!(
                "impl<'ast, {trait_generics_and_bounds}> {trait_name}Node<'ast, {trait_generics}> for {name}{generics}"
            ), |buf| {
                buf.write_block(format!(
                    "fn {fn_name_base}<V: {trait_name}<'ast, {trait_generics}>>(&'ast {muta}self, visitor: &mut V)"
                ), |buf| buf.writeln(format!("visitor.{fn_name}(self)")));
            });
        }
        // The generated function always takes `visitor: &mut V`, even when its
        // body never uses it (e.g. abstract items or all-primitive fields).
        buf.writeln("#[allow(clippy::needless_pass_by_ref_mut)]");
        buf.writeln(format!(
            "pub fn {fn_name}<'ast, V, {trait_generics_and_bounds}>(visitor: &mut V, node: &'ast {muta}{name}{generics})"
        ));
        buf.writeln("where");
        buf.writeln(format!(
            "    V: {trait_name}<'ast, {trait_generics}> + ?Sized,"
        ));
        buf.write_block("", |buf| match item {
            Item::Struct(s) => {
                // Visit each field through a fresh borrow (`&node.f` or
                // `&mut node.f`). Tuple-struct fields use their index.
                for (i, f) in s.fields.iter().enumerate() {
                    let binding = match &f.name {
                        Some(name) => format!("&{muta}node.{name}"),
                        None => format!("&{muta}node.{i}"),
                    };
                    gen_visit_element(c, buf, &binding, &f.ty);
                }
            }
            Item::Enum(e) => {
                // One match arm per variant: bind each field to `binding{i}`,
                // then visit the bindings in field order.
                buf.write_block("match node", |buf| {
                    for v in &e.variants {
                        let vname = &v.name;
                        buf.write_block(format!("{name}::{vname}"), |buf| {
                            // This `name` is the field name, shadowing the item
                            // name only inside this loop.
                            for (i, f) in v.fields.iter().enumerate() {
                                let name = f.name.clone().unwrap_or_else(|| i.to_string());
                                buf.writeln(format!("{name}: binding{i},"));
                            }
                            // NOTE(review): presumably closes the pattern braces and
                            // reopens a block after `=>` for the arm body; exact
                            // behavior is defined by `CodegenBuf`.
                            buf.restart_block("=>");
                            for (i, f) in v.fields.iter().enumerate() {
                                let binding = format!("binding{i}");
                                gen_visit_element(c, buf, &binding, &f.ty);
                            }
                        });
                    }
                });
            }
            // Abstract items have nothing to walk: the free function body is empty.
            Item::Abstract => (),
        });
    }

    buf.into_string()
}
299
300fn gen_visit_element(c: &VisitConfig, buf: &mut CodegenBuf, binding: &str, ty: &Type) {
301    match ty {
302        Type::Primitive => (),
303        Type::Abstract(ty) => {
304            let fn_name = visit_fn_name(c, ty);
305            buf.writeln(format!("visitor.{fn_name}({binding});"));
306        }
307        Type::Option(ty) => {
308            buf.write_block(format!("if let Some(v) = {binding}"), |buf| {
309                gen_visit_element(c, buf, "v", ty)
310            });
311        }
312        Type::Vec(ty) => {
313            buf.write_block(format!("for v in {binding}"), |buf| {
314                gen_visit_element(c, buf, "v", ty)
315            });
316        }
317        Type::Box(ty) => {
318            let binding = match c.mutable {
319                true => format!("&mut *{binding}"),
320                false => format!("&*{binding}"),
321            };
322            gen_visit_element(c, buf, &binding, ty);
323        }
324        Type::Local(s) => {
325            let fn_name = visit_fn_name(c, s);
326            buf.writeln(format!("visitor.{fn_name}({binding});"));
327        }
328        Type::Map { value, .. } => {
329            buf.write_block(format!("for (_, value) in {binding}"), |buf| {
330                gen_visit_element(c, buf, "value", value)
331            });
332        }
333    }
334}
335
336fn fold_fn_name(s: &str) -> String {
337    let mut out = String::from("fold");
338    write_fn_name(&mut out, s);
339    out
340}
341
342fn visit_fn_name(c: &VisitConfig, s: &str) -> String {
343    let mut out = String::from("visit");
344    write_fn_name(&mut out, s);
345    if c.mutable {
346        out.push_str("_mut");
347    }
348    out
349}
350
/// Appends a snake_case rendering of the type name `s` to `out`.
///
/// Everything up to and including the first `::` is dropped first, so the
/// `T::` of an associated type such as `T::FooBar` is ignored. Only the first
/// `::` is considered. Each ASCII uppercase letter is then written as `_`
/// followed by its lowercase form, so appending `T::FooBar` to `visit` yields
/// `visit_foo_bar`. All other characters are copied unchanged.
fn write_fn_name(out: &mut String, s: &str) {
    // `split_once` states "text after the first `::`, or the whole string"
    // directly, replacing `splitn(2, ..).last().unwrap()` and its infallible
    // unwrap.
    let s = s.split_once("::").map_or(s, |(_, rest)| rest);
    for c in s.chars() {
        if c.is_ascii_uppercase() {
            out.push('_');
            out.push(c.to_ascii_lowercase());
        } else {
            out.push(c);
        }
    }
}
364
/// Renders the generic parameter names in `generics`, in key order, as a
/// comma-terminated list such as `T, T2, `, ready to splice before a `>`.
fn trait_generics(generics: &BTreeMap<String, BTreeSet<String>>) -> String {
    generics.keys().fold(String::new(), |mut acc, ident| {
        acc.push_str(ident);
        acc.push_str(", ");
        acc
    })
}
368
/// Renders each generic parameter in `generics`, in key order, together with
/// its trait bounds, as a comma-terminated list for a generic parameter list.
///
/// A parameter with no bounds renders as `T, `. One with bounds renders as
/// `T: A+B, `, with the bounds joined by `+` in the set's sorted order. An empty
/// map yields an empty string.
fn trait_generics_and_bounds(generics: &BTreeMap<String, BTreeSet<String>>) -> String {
    let mut out = String::new();
    for (ident, bounds) in generics {
        out.push_str(ident);
        // `is_empty()` rather than `len() == 0` (clippy::len_zero).
        if !bounds.is_empty() {
            let bounds: Vec<&str> = bounds.iter().map(String::as_str).collect();
            out.push_str(": ");
            out.push_str(&bounds.join("+"));
        }
        out.push_str(", ");
    }
    out
}
381
382fn item_generics(item: &Item, suffix: &str) -> String {
383    if item.generics().is_empty() {
384        "".into()
385    } else {
386        let generics = item
387            .generics()
388            .iter()
389            .map(|g| format!("{}{suffix}", g.name))
390            .join(", ");
391        format!("<{generics}>")
392    }
393}