1use std::collections::{BTreeMap, BTreeSet};
16
17use itertools::Itertools;
18use mz_ore_build::codegen::CodegenBuf;
19
20use crate::ir::{Ir, Item, Type};
21
22pub fn gen_fold(ir: &Ir) -> String {
27 gen_fold_root(ir)
28}
29
30pub fn gen_visit(ir: &Ir) -> String {
35 gen_visit_root(&VisitConfig { mutable: false }, ir)
36}
37
38pub fn gen_visit_mut(ir: &Ir) -> String {
43 gen_visit_root(&VisitConfig { mutable: true }, ir)
44}
45
46pub fn gen_fold_root(ir: &Ir) -> String {
47 let mut generics = BTreeMap::new();
48 for (name, bounds) in &ir.generics {
49 generics.insert(name.clone(), bounds.clone());
50 generics.insert(format!("{name}2"), bounds.clone());
51 }
52 let trait_generics = trait_generics(&generics);
53 let trait_generics_and_bounds = trait_generics_and_bounds(&generics);
54
55 let mut buf = CodegenBuf::new();
56
57 buf.write_block(
58 format!("pub trait Fold<{trait_generics_and_bounds}>"),
59 |buf| {
60 for (name, item) in &ir.items {
61 match item {
62 Item::Abstract => {
63 let name2 = name.replacen("::", "2::", 1);
67 let fn_name = fold_fn_name(name);
68 buf.writeln(format!("fn {fn_name}(&mut self, node: {name}) -> {name2};"))
69 }
70 Item::Struct(_) | Item::Enum(_) => {
71 let generics = item_generics(item, "");
72 let generics2 = item_generics(item, "2");
73 let fn_name = fold_fn_name(name);
74 buf.write_block(
75 format!("fn {fn_name}(&mut self, node: {name}{generics}) -> {name}{generics2}"),
76 |buf| buf.writeln(format!("{fn_name}(self, node)")),
77 );
78 }
79 }
80 }
81 },
82 );
83
84 buf.write_block(
85 format!("pub trait FoldNode<{trait_generics_and_bounds}>"),
86 |buf| {
87 buf.writeln("type Folded;");
88 buf.writeln(format!(
89 "fn fold<F: Fold<{trait_generics}>>(self, folder: &mut F) -> Self::Folded;"
90 ));
91 },
92 );
93
94 for (name, item) in &ir.items {
95 if let Item::Abstract = item {
96 continue;
97 }
98 let generics = item_generics(item, "");
99 let generics2 = item_generics(item, "2");
100 let fn_name = fold_fn_name(name);
101 buf.write_block(
102 format!(
103 "impl<{trait_generics_and_bounds}> FoldNode<{trait_generics}> for {name}{generics}"
104 ),
105 |buf| {
106 buf.writeln(format!("type Folded = {name}{generics2};"));
107 buf.write_block(
108 format!(
109 "fn fold<F: Fold<{trait_generics}>>(self, folder: &mut F) -> Self::Folded"
110 ),
111 |buf| buf.writeln(format!("folder.{fn_name}(self)")),
112 );
113 },
114 );
115
116 buf.writeln("#[allow(clippy::needless_pass_by_ref_mut)]");
117 buf.writeln(format!(
118 "pub fn {fn_name}<F, {trait_generics_and_bounds}>(folder: &mut F, node: {name}{generics}) -> {name}{generics2}"
119 ));
120 buf.writeln("where");
121 buf.writeln(format!(" F: Fold<{trait_generics}> + ?Sized,"));
122 buf.write_block("", |buf| match item {
123 Item::Struct(s) => {
124 buf.write_block(name, |buf| {
125 for (i, f) in s.fields.iter().enumerate() {
126 let field_name = match &f.name {
127 Some(name) => name.clone(),
128 None => i.to_string(),
129 };
130 let binding = format!("node.{field_name}");
131 buf.start_line();
132 buf.write(format!("{field_name}: "));
133 gen_fold_element(buf, &binding, &f.ty);
134 buf.write(",");
135 buf.end_line();
136 }
137 });
138 }
139 Item::Enum(e) => {
140 buf.write_block("match node", |buf| {
141 for v in &e.variants {
142 let vname = &v.name;
143 buf.write_block(format!("{name}::{vname}"), |buf| {
144 for (i, f) in v.fields.iter().enumerate() {
145 let name = f.name.clone().unwrap_or_else(|| i.to_string());
146 buf.writeln(format!("{name}: binding{i},"));
147 }
148 buf.restart_block("=>");
149 buf.write_block(format!("{name}::{vname}"), |buf| {
150 for (i, f) in v.fields.iter().enumerate() {
151 let field_name = match &f.name {
152 Some(name) => name.clone(),
153 None => i.to_string(),
154 };
155 let binding = format!("binding{i}");
156 buf.start_line();
157 buf.write(format!("{field_name}: "));
158 gen_fold_element(buf, &binding, &f.ty);
159 buf.write(",");
160 buf.end_line();
161 }
162 });
163 });
164 }
165 });
166 }
167 Item::Abstract => (),
168 });
169 }
170
171 buf.into_string()
172}
173
174fn gen_fold_element(buf: &mut CodegenBuf, binding: &str, ty: &Type) {
175 match ty {
176 Type::Primitive => buf.write(binding),
177 Type::Abstract(ty) => {
178 let fn_name = fold_fn_name(ty);
179 buf.write(format!("folder.{fn_name}({binding})"));
180 }
181 Type::Option(ty) => {
182 buf.write(format!("{binding}.map(|v| "));
183 gen_fold_element(buf, "v", ty);
184 buf.write(")")
185 }
186 Type::Vec(ty) => {
187 buf.write(format!("{binding}.into_iter().map(|v| "));
188 gen_fold_element(buf, "v", ty);
189 buf.write(").collect()");
190 }
191 Type::Box(ty) => {
192 buf.write("Box::new(");
193 gen_fold_element(buf, &format!("*{binding}"), ty);
194 buf.write(")");
195 }
196 Type::Local(s) => {
197 let fn_name = fold_fn_name(s);
198 buf.write(format!("folder.{fn_name}({binding})"));
199 }
200 Type::Map { key, value } => {
201 buf.write(format!(
202 "{{ std::collections::BTreeMap::from_iter({binding}.into_iter().map(|(k, v)| {{("
203 ));
204 gen_fold_element(buf, "k", key);
205 buf.write(".to_owned(), ");
206 gen_fold_element(buf, "v", value);
207 buf.write(".to_owned()) }) )}")
208 }
209 }
210}
211
/// Selects which flavor of visitor `gen_visit_root` emits.
#[derive(Debug, Clone, Copy)]
struct VisitConfig {
    /// `true` emits `VisitMut` over `&mut` references; `false` emits `Visit`
    /// over shared references.
    mutable: bool,
}
215
216fn gen_visit_root(c: &VisitConfig, ir: &Ir) -> String {
217 let trait_name = if c.mutable { "VisitMut" } else { "Visit" };
218 let fn_name_base = if c.mutable { "visit_mut" } else { "visit" };
219 let muta = if c.mutable { "mut " } else { "" };
220 let trait_generics = trait_generics(&ir.generics);
221 let trait_generics_and_bounds = trait_generics_and_bounds(&ir.generics);
222
223 let mut buf = CodegenBuf::new();
224
225 buf.write_block(
226 format!("pub trait {trait_name}<'ast, {trait_generics_and_bounds}>"),
227 |buf| {
228 for (name, item) in &ir.items {
229 let generics = item_generics(item, "");
230 let fn_name = visit_fn_name(c, name);
231 buf.write_block(
232 format!("fn {fn_name}(&mut self, node: &'ast {muta}{name}{generics})"),
233 |buf| buf.writeln(format!("{fn_name}(self, node)")),
234 );
235 }
236 },
237 );
238
239 buf.write_block(format!(
240 "pub trait {trait_name}Node<'ast, {trait_generics_and_bounds}>"
241 ), |buf| buf.writeln(format!(
242 "fn {fn_name_base}<V: {trait_name}<'ast, {trait_generics}>>(&'ast {muta}self, visitor: &mut V);"
243 )));
244
245 for (name, item) in &ir.items {
246 let generics = item_generics(item, "");
247 let fn_name = visit_fn_name(c, name);
248 if !matches!(item, Item::Abstract) {
249 buf.write_block(format!(
250 "impl<'ast, {trait_generics_and_bounds}> {trait_name}Node<'ast, {trait_generics}> for {name}{generics}"
251 ), |buf| {
252 buf.write_block(format!(
253 "fn {fn_name_base}<V: {trait_name}<'ast, {trait_generics}>>(&'ast {muta}self, visitor: &mut V)"
254 ), |buf| buf.writeln(format!("visitor.{fn_name}(self)")));
255 });
256 }
257 buf.writeln("#[allow(clippy::needless_pass_by_ref_mut)]");
258 buf.writeln(format!(
259 "pub fn {fn_name}<'ast, V, {trait_generics_and_bounds}>(visitor: &mut V, node: &'ast {muta}{name}{generics})"
260 ));
261 buf.writeln("where");
262 buf.writeln(format!(
263 " V: {trait_name}<'ast, {trait_generics}> + ?Sized,"
264 ));
265 buf.write_block("", |buf| match item {
266 Item::Struct(s) => {
267 for (i, f) in s.fields.iter().enumerate() {
268 let binding = match &f.name {
269 Some(name) => format!("&{muta}node.{name}"),
270 None => format!("&{muta}node.{i}"),
271 };
272 gen_visit_element(c, buf, &binding, &f.ty);
273 }
274 }
275 Item::Enum(e) => {
276 buf.write_block("match node", |buf| {
277 for v in &e.variants {
278 let vname = &v.name;
279 buf.write_block(format!("{name}::{vname}"), |buf| {
280 for (i, f) in v.fields.iter().enumerate() {
281 let name = f.name.clone().unwrap_or_else(|| i.to_string());
282 buf.writeln(format!("{name}: binding{i},"));
283 }
284 buf.restart_block("=>");
285 for (i, f) in v.fields.iter().enumerate() {
286 let binding = format!("binding{i}");
287 gen_visit_element(c, buf, &binding, &f.ty);
288 }
289 });
290 }
291 });
292 }
293 Item::Abstract => (),
294 });
295 }
296
297 buf.into_string()
298}
299
300fn gen_visit_element(c: &VisitConfig, buf: &mut CodegenBuf, binding: &str, ty: &Type) {
301 match ty {
302 Type::Primitive => (),
303 Type::Abstract(ty) => {
304 let fn_name = visit_fn_name(c, ty);
305 buf.writeln(format!("visitor.{fn_name}({binding});"));
306 }
307 Type::Option(ty) => {
308 buf.write_block(format!("if let Some(v) = {binding}"), |buf| {
309 gen_visit_element(c, buf, "v", ty)
310 });
311 }
312 Type::Vec(ty) => {
313 buf.write_block(format!("for v in {binding}"), |buf| {
314 gen_visit_element(c, buf, "v", ty)
315 });
316 }
317 Type::Box(ty) => {
318 let binding = match c.mutable {
319 true => format!("&mut *{binding}"),
320 false => format!("&*{binding}"),
321 };
322 gen_visit_element(c, buf, &binding, ty);
323 }
324 Type::Local(s) => {
325 let fn_name = visit_fn_name(c, s);
326 buf.writeln(format!("visitor.{fn_name}({binding});"));
327 }
328 Type::Map { value, .. } => {
329 buf.write_block(format!("for (_, value) in {binding}"), |buf| {
330 gen_visit_element(c, buf, "value", value)
331 });
332 }
333 }
334}
335
336fn fold_fn_name(s: &str) -> String {
337 let mut out = String::from("fold");
338 write_fn_name(&mut out, s);
339 out
340}
341
342fn visit_fn_name(c: &VisitConfig, s: &str) -> String {
343 let mut out = String::from("visit");
344 write_fn_name(&mut out, s);
345 if c.mutable {
346 out.push_str("_mut");
347 }
348 out
349}
350
/// Appends the snake_case form of item name `s` to `out`.
///
/// Everything up to and including the first `::` is dropped (`T::FooBar` ->
/// `FooBar`). Each ASCII uppercase letter is then lowered and prefixed with
/// `_`; all other characters are copied unchanged.
fn write_fn_name(out: &mut String, s: &str) {
    let tail = s.split_once("::").map_or(s, |(_, rest)| rest);
    for ch in tail.chars() {
        if ch.is_ascii_uppercase() {
            out.push('_');
        }
        // Non-ASCII-uppercase characters are returned unchanged here.
        out.push(ch.to_ascii_lowercase());
    }
}
364
/// Renders the generic parameter names in key order, each followed by `", "`
/// (e.g. `"T, U, "`), ready to be spliced into a `<...>` list.
fn trait_generics(generics: &BTreeMap<String, BTreeSet<String>>) -> String {
    let mut out = String::new();
    for ident in generics.keys() {
        out.push_str(ident);
        out.push_str(", ");
    }
    out
}
368
/// Renders generic parameters with their bounds in key order, each followed
/// by `", "`.
///
/// An unbounded parameter renders as `T, `; a bounded one as `T: A+B, `, with
/// bounds joined by `+` in set order.
fn trait_generics_and_bounds(generics: &BTreeMap<String, BTreeSet<String>>) -> String {
    let mut out = String::new();
    for (ident, bounds) in generics {
        out.push_str(ident);
        if !bounds.is_empty() {
            out.push_str(": ");
            let joined: Vec<&str> = bounds.iter().map(String::as_str).collect();
            out.push_str(&joined.join("+"));
        }
        out.push_str(", ");
    }
    out
}
381
382fn item_generics(item: &Item, suffix: &str) -> String {
383 if item.generics().is_empty() {
384 "".into()
385 } else {
386 let generics = item
387 .generics()
388 .iter()
389 .map(|g| format!("{}{suffix}", g.name))
390 .join(", ");
391 format!("<{generics}>")
392 }
393}