mz_expr_parser/
parser.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10// Copyright Materialize, Inc. and contributors. All rights reserved.
11//
12// Use of this software is governed by the Business Source License
13// included in the LICENSE file.
14//
15// As of the Change Date specified in that file, in accordance with
16// the Business Source License, use of this software will be governed
17// by the Apache License, Version 2.0.
18
19use mz_ore::collections::CollectionExt;
20use proc_macro2::LineColumn;
21use syn::Error;
22use syn::parse::{Parse, ParseStream, Parser};
23use syn::spanned::Spanned;
24
25use super::TestCatalog;
26
27use self::util::*;
28
29/// Builds a [mz_expr::MirRelationExpr] from a string.
30pub fn try_parse_mir(catalog: &TestCatalog, s: &str) -> Result<mz_expr::MirRelationExpr, String> {
31    // Define a Parser that constructs a (read-only) parsing context `ctx` and
32    // delegates to `relation::parse_expr` by passing a `ctx` as a shared ref.
33    let parser = move |input: ParseStream| {
34        let ctx = Ctx { catalog };
35        relation::parse_expr(&ctx, input)
36    };
37    // Since the syn lexer doesn't parse comments, we replace all `// {`
38    // occurrences in the input string with `:: {`.
39    let s = s.replace("// {", ":: {");
40    // Call the parser with the given input string.
41    let mut expr = parser.parse_str(&s).map_err(|err| {
42        let (line, column) = (err.span().start().line, err.span().start().column);
43        format!("parse error at {line}:{column}:\n{err}\n")
44    })?;
45    // Fix the types of the local let bindings of the parsed expression in a
46    // post-processing pass.
47    relation::fix_types(&mut expr, &mut relation::FixTypesCtx::default())?;
48    // Return the parsed, post-processed expression.
49    Ok(expr)
50}
51
52/// Builds a source definition from a string.
53pub fn try_parse_def(catalog: &TestCatalog, s: &str) -> Result<Def, String> {
54    // Define a Parser that constructs a (read-only) parsing context `ctx` and
55    // delegates to `relation::parse_expr` by passing a `ctx` as a shared ref.
56    let parser = move |input: ParseStream| {
57        let ctx = Ctx { catalog };
58        def::parse_def(&ctx, input)
59    };
60    // Call the parser with the given input string.
61    let def = parser.parse_str(s).map_err(|err| {
62        let (line, column) = (err.span().start().line, err.span().start().column);
63        format!("parse error at {line}:{column}:\n{err}\n")
64    })?;
65    // Return the parsed, post-processed expression.
66    Ok(def)
67}
68
69/// Support for parsing [mz_expr::MirRelationExpr].
70mod relation {
71    use std::collections::BTreeMap;
72
73    use mz_expr::{AccessStrategy, Id, JoinImplementation, LocalId, MirRelationExpr};
74    use mz_repr::{Diff, RelationType, Row, ScalarType};
75
76    use crate::parser::analyses::Analyses;
77
78    use super::*;
79
80    type Result = syn::Result<MirRelationExpr>;
81
82    pub fn parse_expr(ctx: CtxRef, input: ParseStream) -> Result {
83        let lookahead = input.lookahead1();
84        if lookahead.peek(kw::Constant) {
85            parse_constant(ctx, input)
86        } else if lookahead.peek(kw::Get) {
87            parse_get(ctx, input)
88        } else if lookahead.peek(kw::Return) {
89            parse_let_or_letrec_old(ctx, input)
90        } else if lookahead.peek(kw::With) {
91            parse_let_or_letrec(ctx, input)
92        } else if lookahead.peek(kw::Project) {
93            parse_project(ctx, input)
94        } else if lookahead.peek(kw::Map) {
95            parse_map(ctx, input)
96        } else if lookahead.peek(kw::FlatMap) {
97            parse_flat_map(ctx, input)
98        } else if lookahead.peek(kw::Filter) {
99            parse_filter(ctx, input)
100        } else if lookahead.peek(kw::CrossJoin) {
101            parse_cross_join(ctx, input)
102        } else if lookahead.peek(kw::Join) {
103            parse_join(ctx, input)
104        } else if lookahead.peek(kw::Distinct) {
105            parse_distinct(ctx, input)
106        } else if lookahead.peek(kw::Reduce) {
107            parse_reduce(ctx, input)
108        } else if lookahead.peek(kw::TopK) {
109            parse_top_k(ctx, input)
110        } else if lookahead.peek(kw::Negate) {
111            parse_negate(ctx, input)
112        } else if lookahead.peek(kw::Threshold) {
113            parse_threshold(ctx, input)
114        } else if lookahead.peek(kw::Union) {
115            parse_union(ctx, input)
116        } else if lookahead.peek(kw::ArrangeBy) {
117            parse_arrange_by(ctx, input)
118        } else {
119            Err(lookahead.error())
120        }
121    }
122
123    fn parse_constant(ctx: CtxRef, input: ParseStream) -> Result {
124        let constant = input.parse::<kw::Constant>()?;
125
126        let parse_typ = |input: ParseStream| -> syn::Result<RelationType> {
127            let analyses = analyses::parse_analyses(input)?;
128            let Some(column_types) = analyses.types else {
129                let msg = "Missing expected `types` analyses for Constant line";
130                Err(Error::new(input.span(), msg))?
131            };
132            let keys = analyses.keys.unwrap_or_default();
133            Ok(RelationType { column_types, keys })
134        };
135
136        if input.eat3(syn::Token![<], kw::empty, syn::Token![>]) {
137            let typ = parse_typ(input)?;
138            Ok(MirRelationExpr::constant(vec![], typ))
139        } else {
140            let typ = parse_typ(input)?;
141            let parse_children = ParseChildren::new(input, constant.span().start());
142            let rows = Ok(parse_children.parse_many(ctx, parse_constant_entry)?);
143            Ok(MirRelationExpr::Constant { rows, typ })
144        }
145    }
146
147    fn parse_constant_entry(_ctx: CtxRef, input: ParseStream) -> syn::Result<(Row, Diff)> {
148        input.parse::<syn::Token![-]>()?;
149
150        let (row, diff);
151
152        let inner1;
153        syn::parenthesized!(inner1 in input);
154
155        if inner1.peek(syn::token::Paren) {
156            let inner2;
157            syn::parenthesized!(inner2 in inner1);
158            row = inner2.parse::<Parsed<Row>>()?.into();
159            inner1.parse::<kw::x>()?;
160            diff = match inner1.parse::<syn::Lit>()? {
161                syn::Lit::Int(l) => Ok(l.base10_parse::<Diff>()?),
162                _ => Err(Error::new(inner1.span(), "expected Diff literal")),
163            }?;
164        } else {
165            row = inner1.parse::<Parsed<Row>>()?.into();
166            diff = Diff::ONE;
167        }
168
169        Ok((row, diff))
170    }
171
172    fn parse_get(ctx: CtxRef, input: ParseStream) -> Result {
173        input.parse::<kw::Get>()?;
174
175        let ident = input.parse::<syn::Ident>()?;
176        match ctx.catalog.get(&ident.to_string()) {
177            Some((id, _cols, typ)) => Ok(MirRelationExpr::Get {
178                id: Id::Global(*id),
179                typ: typ.clone(),
180                access_strategy: AccessStrategy::UnknownOrLocal,
181            }),
182            None => Ok(MirRelationExpr::Get {
183                id: Id::Local(parse_local_id(ident)?),
184                typ: RelationType::empty(),
185                access_strategy: AccessStrategy::UnknownOrLocal,
186            }),
187        }
188    }
189
190    /// Parses a Let or a LetRec with the old order: Return first, and then CTEs in descending order.
191    fn parse_let_or_letrec_old(ctx: CtxRef, input: ParseStream) -> Result {
192        let return_ = input.parse::<kw::Return>()?;
193        let parse_body = ParseChildren::new(input, return_.span().start());
194        let body = parse_body.parse_one(ctx, parse_expr)?;
195
196        let with = input.parse::<kw::With>()?;
197        let recursive = input.eat2(kw::Mutually, kw::Recursive);
198        let parse_ctes = ParseChildren::new(input, with.span().start());
199        let mut ctes = parse_ctes.parse_many(ctx, parse_cte)?;
200
201        if ctes.is_empty() {
202            let msg = "At least one Let/LetRec cte binding expected";
203            Err(Error::new(input.span(), msg))?
204        }
205
206        ctes.reverse();
207        let cte_ids = ctes.iter().map(|(id, _, _)| id);
208        if !cte_ids.clone().is_sorted() {
209            let msg = format!(
210                "Error parsing Let/LetRec: seen Return before With, but cte ids are not ordered descending: {:?}",
211                cte_ids.collect::<Vec<_>>()
212            );
213            Err(Error::new(input.span(), msg))?
214        }
215        build_let_or_let_rec(ctes, body, recursive, with)
216    }
217
218    /// Parses a Let or a LetRec with the new order: CTEs first in ascending order, and then Return.
219    fn parse_let_or_letrec(ctx: CtxRef, input: ParseStream) -> Result {
220        let with = input.parse::<kw::With>()?;
221        let recursive = input.eat2(kw::Mutually, kw::Recursive);
222        let parse_ctes = ParseChildren::new(input, with.span().start());
223        let ctes = parse_ctes.parse_many(ctx, parse_cte)?;
224
225        let return_ = input.parse::<kw::Return>()?;
226        let parse_body = ParseChildren::new(input, return_.span().start());
227        let body = parse_body.parse_one(ctx, parse_expr)?;
228
229        if ctes.is_empty() {
230            let msg = "At least one `let cte` binding expected";
231            Err(Error::new(input.span(), msg))?
232        }
233
234        let cte_ids = ctes.iter().map(|(id, _, _)| id);
235        if !cte_ids.clone().is_sorted() {
236            let msg = format!(
237                "Error parsing Let/LetRec: seen With before Return, but cte ids are not ordered ascending: {:?}",
238                cte_ids.collect::<Vec<_>>()
239            );
240            Err(Error::new(input.span(), msg))?
241        }
242        build_let_or_let_rec(ctes, body, recursive, with)
243    }
244
245    fn build_let_or_let_rec(
246        ctes: Vec<(LocalId, Analyses, MirRelationExpr)>,
247        body: MirRelationExpr,
248        recursive: bool,
249        with: kw::With,
250    ) -> Result {
251        if recursive {
252            let (mut ids, mut values, mut limits) = (vec![], vec![], vec![]);
253            for (id, analyses, value) in ctes.into_iter() {
254                let typ = {
255                    let Some(column_types) = analyses.types else {
256                        let msg = format!("`let {}` needs a `types` analyses", id);
257                        Err(Error::new(with.span(), msg))?
258                    };
259                    let keys = analyses.keys.unwrap_or_default();
260                    RelationType { column_types, keys }
261                };
262
263                // An ugly-ugly hack to pass the type information of the WMR CTE
264                // to the `fix_types` pass.
265                let value = {
266                    let get_cte = MirRelationExpr::Get {
267                        id: Id::Local(id),
268                        typ: typ.clone(),
269                        access_strategy: AccessStrategy::UnknownOrLocal,
270                    };
271                    // Do not use the `union` smart constructor here!
272                    MirRelationExpr::Union {
273                        base: Box::new(get_cte),
274                        inputs: vec![value],
275                    }
276                };
277
278                ids.push(id);
279                values.push(value);
280                limits.push(None); // TODO: support limits
281            }
282
283            Ok(MirRelationExpr::LetRec {
284                ids,
285                values,
286                limits,
287                body: Box::new(body),
288            })
289        } else {
290            let mut body = body;
291            for (id, _, value) in ctes.into_iter().rev() {
292                body = MirRelationExpr::Let {
293                    id,
294                    value: Box::new(value),
295                    body: Box::new(body),
296                };
297            }
298            Ok(body)
299        }
300    }
301
302    fn parse_cte(
303        ctx: CtxRef,
304        input: ParseStream,
305    ) -> syn::Result<(LocalId, analyses::Analyses, MirRelationExpr)> {
306        let cte = input.parse::<kw::cte>()?;
307
308        let ident = input.parse::<syn::Ident>()?;
309        let id = parse_local_id(ident)?;
310
311        input.parse::<syn::Token![=]>()?;
312
313        let analyses = analyses::parse_analyses(input)?;
314
315        let parse_value = ParseChildren::new(input, cte.span().start());
316        let value = parse_value.parse_one(ctx, parse_expr)?;
317
318        Ok((id, analyses, value))
319    }
320
321    fn parse_project(ctx: CtxRef, input: ParseStream) -> Result {
322        let project = input.parse::<kw::Project>()?;
323
324        let content;
325        syn::parenthesized!(content in input);
326        let outputs = content.parse_comma_sep(scalar::parse_column_index)?;
327        let parse_input = ParseChildren::new(input, project.span().start());
328        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
329
330        Ok(MirRelationExpr::Project { input, outputs })
331    }
332
333    fn parse_map(ctx: CtxRef, input: ParseStream) -> Result {
334        let map = input.parse::<kw::Map>()?;
335
336        let scalars = {
337            let inner;
338            syn::parenthesized!(inner in input);
339            scalar::parse_exprs(&inner)?
340        };
341
342        let parse_input = ParseChildren::new(input, map.span().start());
343        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
344
345        Ok(MirRelationExpr::Map { input, scalars })
346    }
347
348    fn parse_flat_map(ctx: CtxRef, input: ParseStream) -> Result {
349        use mz_expr::TableFunc::*;
350
351        let flat_map = input.parse::<kw::FlatMap>()?;
352
353        let ident = input.parse::<syn::Ident>()?;
354        let func = match ident.to_string().to_lowercase().as_str() {
355            "unnest_list" => UnnestList {
356                el_typ: ScalarType::Int64, // FIXME
357            },
358            "unnest_array" => UnnestArray {
359                el_typ: ScalarType::Int64, // FIXME
360            },
361            "wrap1" => Wrap {
362                types: vec![
363                    ScalarType::Int64.nullable(true), // FIXME
364                ],
365                width: 1,
366            },
367            "wrap2" => Wrap {
368                types: vec![
369                    ScalarType::Int64.nullable(true), // FIXME
370                    ScalarType::Int64.nullable(true), // FIXME
371                ],
372                width: 2,
373            },
374            "wrap3" => Wrap {
375                types: vec![
376                    ScalarType::Int64.nullable(true), // FIXME
377                    ScalarType::Int64.nullable(true), // FIXME
378                    ScalarType::Int64.nullable(true), // FIXME
379                ],
380                width: 3,
381            },
382            "generate_series" => GenerateSeriesInt64,
383            "jsonb_object_keys" => JsonbObjectKeys,
384            _ => Err(Error::new(ident.span(), "unsupported function name"))?,
385        };
386
387        let exprs = {
388            let inner;
389            syn::parenthesized!(inner in input);
390            scalar::parse_exprs(&inner)?
391        };
392
393        let parse_input = ParseChildren::new(input, flat_map.span().start());
394        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
395
396        Ok(MirRelationExpr::FlatMap { input, func, exprs })
397    }
398
399    fn parse_filter(ctx: CtxRef, input: ParseStream) -> Result {
400        use mz_expr::MirScalarExpr::CallVariadic;
401        use mz_expr::VariadicFunc::And;
402
403        let filter = input.parse::<kw::Filter>()?;
404
405        let predicates = match scalar::parse_expr(input)? {
406            CallVariadic { func: And, exprs } => exprs,
407            expr => vec![expr],
408        };
409
410        let parse_input = ParseChildren::new(input, filter.span().start());
411        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
412
413        Ok(MirRelationExpr::Filter { input, predicates })
414    }
415
416    fn parse_cross_join(ctx: CtxRef, input: ParseStream) -> Result {
417        let join = input.parse::<kw::CrossJoin>()?;
418
419        let parse_inputs = ParseChildren::new(input, join.span().start());
420        let inputs = parse_inputs.parse_many(ctx, parse_expr)?;
421
422        Ok(MirRelationExpr::Join {
423            inputs,
424            equivalences: vec![],
425            implementation: JoinImplementation::Unimplemented,
426        })
427    }
428
429    fn parse_join(ctx: CtxRef, input: ParseStream) -> Result {
430        let join = input.parse::<kw::Join>()?;
431
432        input.parse::<kw::on>()?;
433        input.parse::<syn::Token![=]>()?;
434        let inner;
435        syn::parenthesized!(inner in input);
436        let equivalences = scalar::parse_join_equivalences(&inner)?;
437
438        let parse_inputs = ParseChildren::new(input, join.span().start());
439        let inputs = parse_inputs.parse_many(ctx, parse_expr)?;
440
441        Ok(MirRelationExpr::Join {
442            inputs,
443            equivalences,
444            implementation: JoinImplementation::Unimplemented,
445        })
446    }
447
448    fn parse_distinct(ctx: CtxRef, input: ParseStream) -> Result {
449        let reduce = input.parse::<kw::Distinct>()?;
450
451        let group_key = if input.eat(kw::project) {
452            input.parse::<syn::Token![=]>()?;
453            let inner;
454            syn::bracketed!(inner in input);
455            inner.parse_comma_sep(scalar::parse_expr)?
456        } else {
457            vec![]
458        };
459
460        let monotonic = input.eat(kw::monotonic);
461
462        let expected_group_size = if input.eat(kw::exp_group_size) {
463            input.parse::<syn::Token![=]>()?;
464            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
465        } else {
466            None
467        };
468
469        let parse_inputs = ParseChildren::new(input, reduce.span().start());
470        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
471
472        Ok(MirRelationExpr::Reduce {
473            input,
474            group_key,
475            aggregates: vec![],
476            monotonic,
477            expected_group_size,
478        })
479    }
480
481    fn parse_reduce(ctx: CtxRef, input: ParseStream) -> Result {
482        let reduce = input.parse::<kw::Reduce>()?;
483
484        let group_key = if input.eat(kw::group_by) {
485            input.parse::<syn::Token![=]>()?;
486            let inner;
487            syn::bracketed!(inner in input);
488            inner.parse_comma_sep(scalar::parse_expr)?
489        } else {
490            vec![]
491        };
492
493        let aggregates = {
494            input.parse::<kw::aggregates>()?;
495            input.parse::<syn::Token![=]>()?;
496            let inner;
497            syn::bracketed!(inner in input);
498            inner.parse_comma_sep(aggregate::parse_expr)?
499        };
500
501        let monotonic = input.eat(kw::monotonic);
502
503        let expected_group_size = if input.eat(kw::exp_group_size) {
504            input.parse::<syn::Token![=]>()?;
505            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
506        } else {
507            None
508        };
509
510        let parse_inputs = ParseChildren::new(input, reduce.span().start());
511        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
512
513        Ok(MirRelationExpr::Reduce {
514            input,
515            group_key,
516            aggregates,
517            monotonic,
518            expected_group_size,
519        })
520    }
521
522    fn parse_top_k(ctx: CtxRef, input: ParseStream) -> Result {
523        let top_k = input.parse::<kw::TopK>()?;
524
525        let group_key = if input.eat(kw::group_by) {
526            input.parse::<syn::Token![=]>()?;
527            let inner;
528            syn::bracketed!(inner in input);
529            inner.parse_comma_sep(scalar::parse_column_index)?
530        } else {
531            vec![]
532        };
533
534        let order_key = if input.eat(kw::order_by) {
535            input.parse::<syn::Token![=]>()?;
536            let inner;
537            syn::bracketed!(inner in input);
538            inner.parse_comma_sep(scalar::parse_column_order)?
539        } else {
540            vec![]
541        };
542
543        let limit = if input.eat(kw::limit) {
544            input.parse::<syn::Token![=]>()?;
545            Some(scalar::parse_expr(input)?)
546        } else {
547            None
548        };
549
550        let offset = if input.eat(kw::offset) {
551            input.parse::<syn::Token![=]>()?;
552            input.parse::<syn::LitInt>()?.base10_parse::<usize>()?
553        } else {
554            0
555        };
556
557        let monotonic = input.eat(kw::monotonic);
558
559        let expected_group_size = if input.eat(kw::exp_group_size) {
560            input.parse::<syn::Token![=]>()?;
561            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
562        } else {
563            None
564        };
565
566        let parse_inputs = ParseChildren::new(input, top_k.span().start());
567        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
568
569        Ok(MirRelationExpr::TopK {
570            input,
571            group_key,
572            order_key,
573            limit,
574            offset,
575            monotonic,
576            expected_group_size,
577        })
578    }
579
580    fn parse_negate(ctx: CtxRef, input: ParseStream) -> Result {
581        let negate = input.parse::<kw::Negate>()?;
582
583        let parse_input = ParseChildren::new(input, negate.span().start());
584        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
585
586        Ok(MirRelationExpr::Negate { input })
587    }
588
589    fn parse_threshold(ctx: CtxRef, input: ParseStream) -> Result {
590        let threshold = input.parse::<kw::Threshold>()?;
591
592        let parse_input = ParseChildren::new(input, threshold.span().start());
593        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
594
595        Ok(MirRelationExpr::Threshold { input })
596    }
597
598    fn parse_union(ctx: CtxRef, input: ParseStream) -> Result {
599        let union = input.parse::<kw::Union>()?;
600
601        let parse_inputs = ParseChildren::new(input, union.span().start());
602        let mut children = parse_inputs.parse_many(ctx, parse_expr)?;
603        let inputs = children.split_off(1);
604        let base = Box::new(children.into_element());
605
606        Ok(MirRelationExpr::Union { base, inputs })
607    }
608
609    fn parse_arrange_by(ctx: CtxRef, input: ParseStream) -> Result {
610        let arrange_by = input.parse::<kw::ArrangeBy>()?;
611
612        let keys = {
613            input.parse::<kw::keys>()?;
614            input.parse::<syn::Token![=]>()?;
615            let inner;
616            syn::bracketed!(inner in input);
617            inner.parse_comma_sep(|input| {
618                let inner;
619                syn::bracketed!(inner in input);
620                scalar::parse_exprs(&inner)
621            })?
622        };
623
624        let parse_input = ParseChildren::new(input, arrange_by.span().start());
625        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
626
627        Ok(MirRelationExpr::ArrangeBy { input, keys })
628    }
629
630    fn parse_local_id(ident: syn::Ident) -> syn::Result<LocalId> {
631        if ident.to_string().starts_with('l') {
632            let n = ident.to_string()[1..]
633                .parse::<u64>()
634                .map_err(|err| Error::new(ident.span(), err.to_string()))?;
635            Ok(mz_expr::LocalId::new(n))
636        } else {
637            Err(Error::new(ident.span(), "invalid LocalId"))
638        }
639    }
640
    /// Mutable state threaded through the [fix_types] pass.
    #[derive(Default)]
    pub struct FixTypesCtx {
        /// Types of the local ids that are currently in scope. Entries are
        /// inserted when `fix_types` enters a `Let`/`LetRec` and removed (or
        /// restored to a shadowed prior entry) when it leaves.
        env: BTreeMap<LocalId, RelationType>,
        /// Stack of types computed for already processed expressions. A
        /// parent pops the types of its inputs from here to derive its own.
        typ: Vec<RelationType>,
    }
646
647    pub fn fix_types(
648        expr: &mut MirRelationExpr,
649        ctx: &mut FixTypesCtx,
650    ) -> std::result::Result<(), String> {
651        match expr {
652            MirRelationExpr::Let { id, value, body } => {
653                fix_types(value, ctx)?;
654                let value_typ = ctx.typ.pop().expect("value type");
655                let prior_typ = ctx.env.insert(id.clone(), value_typ);
656                fix_types(body, ctx)?;
657                ctx.env.remove(id);
658                if let Some(prior_typ) = prior_typ {
659                    ctx.env.insert(id.clone(), prior_typ);
660                }
661            }
662            MirRelationExpr::LetRec {
663                ids,
664                values,
665                body,
666                limits: _,
667            } => {
668                // An ugly-ugly hack to pass the type information of the WMR CTE
669                // to the `fix_types` pass.
670                let mut prior_typs = BTreeMap::default();
671                for (id, value) in std::iter::zip(ids.iter_mut(), values.iter_mut()) {
672                    let MirRelationExpr::Union { base, mut inputs } = value.take_dangerous() else {
673                        unreachable!("ensured by construction");
674                    };
675                    let MirRelationExpr::Get { id: _, typ, .. } = *base else {
676                        unreachable!("ensured by construction");
677                    };
678                    if let Some(prior_typ) = ctx.env.insert(id.clone(), typ) {
679                        prior_typs.insert(id.clone(), prior_typ);
680                    }
681                    *value = inputs.pop().expect("ensured by construction");
682                }
683                for value in values.iter_mut() {
684                    fix_types(value, ctx)?;
685                }
686                fix_types(body, ctx)?;
687                for id in ids.iter() {
688                    ctx.env.remove(id);
689                    if let Some(prior_typ) = prior_typs.remove(id) {
690                        ctx.env.insert(id.clone(), prior_typ);
691                    }
692                }
693            }
694            MirRelationExpr::Get {
695                id: Id::Local(id),
696                typ,
697                ..
698            } => {
699                let env_typ = match ctx.env.get(&*id) {
700                    Some(env_typ) => env_typ,
701                    None => Err(format!("Cannot fix type of unbound CTE {}", id))?,
702                };
703                *typ = env_typ.clone();
704                ctx.typ.push(env_typ.clone());
705            }
706            _ => {
707                for input in expr.children_mut() {
708                    fix_types(input, ctx)?;
709                }
710                let input_types = ctx.typ.split_off(ctx.typ.len() - expr.num_inputs());
711                ctx.typ.push(expr.typ_with_input_types(&input_types));
712            }
713        };
714
715        Ok(())
716    }
717}
718
719/// Support for parsing [mz_expr::MirScalarExpr].
720mod scalar {
721    use mz_expr::{
722        BinaryFunc, ColumnOrder, MirScalarExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
723    };
724    use mz_repr::{AsColumnType, Datum, Row, RowArena, ScalarType};
725
726    use super::*;
727
728    type Result = syn::Result<MirScalarExpr>;
729
    /// Parses a comma-separated sequence of scalar expressions by handing
    /// [parse_expr] to `parse_comma_sep` on the given stream.
    pub fn parse_exprs(input: ParseStream) -> syn::Result<Vec<MirScalarExpr>> {
        input.parse_comma_sep(parse_expr)
    }
733
734    /// Parses a single expression.
735    ///
736    /// Because in EXPLAIN contexts parentheses might be optional, we need to
737    /// correctly handle operator precedence of infix operators.
738    ///
739    /// Currently, this works in two steps:
740    ///
741    /// 1. Convert the original infix expression to a postfix expression using
742    ///    an adapted variant of this [algorithm] with precedence taken from the
743    ///    Postgres [precedence] docs. Parenthesized operands are parsed in one
744    ///    step, so steps (3-4) from the [algorithm] are not needed here.
745    /// 2. Convert the postfix vector into a single [MirScalarExpr].
746    ///
747    /// [algorithm]: <https://www.prepbytes.com/blog/stacks/infix-to-postfix-conversion-using-stack/>
748    /// [precedence]: <https://www.postgresql.org/docs/7.2/sql-precedence.html>
749    pub fn parse_expr(input: ParseStream) -> Result {
750        let line = input.span().start().line;
751
752        /// Helper struct to keep track of the parsing state.
753        #[derive(Debug)]
754        enum Op {
755            Unr(mz_expr::UnaryFunc), // unary
756            Neg(mz_expr::UnaryFunc), // negated unary (append -.not() on fold)
757            Bin(mz_expr::BinaryFunc),
758            Var(mz_expr::VariadicFunc),
759        }
760
761        impl Op {
762            fn precedence(&self) -> Option<usize> {
763                match self {
764                    // 01: logical disjunction
765                    Op::Var(mz_expr::VariadicFunc::Or) => Some(1),
766                    // 02: logical conjunction
767                    Op::Var(mz_expr::VariadicFunc::And) => Some(2),
768                    // 04: equality, assignment
769                    Op::Bin(mz_expr::BinaryFunc::Eq) => Some(4),
770                    Op::Bin(mz_expr::BinaryFunc::NotEq) => Some(4),
771                    // 05: less than, greater than
772                    Op::Bin(mz_expr::BinaryFunc::Gt) => Some(5),
773                    Op::Bin(mz_expr::BinaryFunc::Gte) => Some(5),
774                    Op::Bin(mz_expr::BinaryFunc::Lt) => Some(5),
775                    Op::Bin(mz_expr::BinaryFunc::Lte) => Some(5),
776                    // 13: test for TRUE, FALSE, UNKNOWN, NULL
777                    Op::Unr(mz_expr::UnaryFunc::IsNull(_)) => Some(13),
778                    Op::Neg(mz_expr::UnaryFunc::IsNull(_)) => Some(13),
779                    Op::Unr(mz_expr::UnaryFunc::IsTrue(_)) => Some(13),
780                    Op::Neg(mz_expr::UnaryFunc::IsTrue(_)) => Some(13),
781                    Op::Unr(mz_expr::UnaryFunc::IsFalse(_)) => Some(13),
782                    Op::Neg(mz_expr::UnaryFunc::IsFalse(_)) => Some(13),
783                    // 14: addition, subtraction
784                    Op::Bin(mz_expr::BinaryFunc::AddInt64) => Some(14),
                    // 15: multiplication, division, modulo
786                    Op::Bin(mz_expr::BinaryFunc::MulInt64) => Some(15),
787                    Op::Bin(mz_expr::BinaryFunc::DivInt64) => Some(15),
788                    Op::Bin(mz_expr::BinaryFunc::ModInt64) => Some(15),
789                    // unsupported
790                    _ => None,
791                }
792            }
793        }
794
795        /// Helper struct for entries in the postfix vector.
796        #[derive(Debug)]
797        enum Entry {
798            Operand(MirScalarExpr),
799            Operator(Op),
800        }
801
802        let mut opstack = vec![];
803        let mut postfix = vec![];
804        let mut exp_opd = true; // expects an argument of an operator
805
806        // Scan the given infix expression from left to right.
807        while !input.is_empty() && input.span().start().line == line {
808            // Operands and operators alternate.
809            if exp_opd {
810                postfix.push(Entry::Operand(parse_operand(input)?));
811                exp_opd = false;
812            } else {
813                // If the current symbol is an operator, then bind it to op.
814                // Else it is an operand - append it to postfix and continue.
815                let op = if input.eat(syn::Token![=]) {
816                    exp_opd = true;
817                    Op::Bin(mz_expr::BinaryFunc::Eq)
818                } else if input.eat(syn::Token![!=]) {
819                    exp_opd = true;
820                    Op::Bin(mz_expr::BinaryFunc::NotEq)
821                } else if input.eat(syn::Token![>=]) {
822                    exp_opd = true;
823                    Op::Bin(mz_expr::BinaryFunc::Gte)
824                } else if input.eat(syn::Token![>]) {
825                    exp_opd = true;
826                    Op::Bin(mz_expr::BinaryFunc::Gt)
827                } else if input.eat(syn::Token![<=]) {
828                    exp_opd = true;
829                    Op::Bin(mz_expr::BinaryFunc::Lte)
830                } else if input.eat(syn::Token![<]) {
831                    exp_opd = true;
832                    Op::Bin(mz_expr::BinaryFunc::Lt)
833                } else if input.eat(syn::Token![+]) {
834                    exp_opd = true;
835                    Op::Bin(mz_expr::BinaryFunc::AddInt64) // TODO: fix placeholder
836                } else if input.eat(syn::Token![*]) {
837                    exp_opd = true;
838                    Op::Bin(mz_expr::BinaryFunc::MulInt64) // TODO: fix placeholder
839                } else if input.eat(syn::Token![/]) {
840                    exp_opd = true;
841                    Op::Bin(mz_expr::BinaryFunc::DivInt64) // TODO: fix placeholder
842                } else if input.eat(syn::Token![%]) {
843                    exp_opd = true;
844                    Op::Bin(mz_expr::BinaryFunc::ModInt64) // TODO: fix placeholder
845                } else if input.eat(kw::AND) {
846                    exp_opd = true;
847                    Op::Var(mz_expr::VariadicFunc::And)
848                } else if input.eat(kw::OR) {
849                    exp_opd = true;
850                    Op::Var(mz_expr::VariadicFunc::Or)
851                } else if input.eat(kw::IS) {
852                    let negate = input.eat(kw::NOT);
853
854                    let lookahead = input.lookahead1();
855                    let func = if input.look_and_eat(kw::NULL, &lookahead) {
856                        mz_expr::func::IsNull.into()
857                    } else if input.look_and_eat(kw::TRUE, &lookahead) {
858                        mz_expr::func::IsTrue.into()
859                    } else if input.look_and_eat(kw::FALSE, &lookahead) {
860                        mz_expr::func::IsFalse.into()
861                    } else {
862                        Err(lookahead.error())?
863                    };
864
865                    if negate { Op::Neg(func) } else { Op::Unr(func) }
866                } else {
867                    // We were expecting an optional operator but didn't find
868                    // anything. Exit the parsing loop and process the postfix
869                    // vector.
870                    break;
871                };
872
873                // First, pop the operators which are already on the opstack that
874                // have higher or equal precedence than the current operator and
875                // append them to the postfix.
876                while opstack
877                    .last()
878                    .map(|op1: &Op| op1.precedence() >= op.precedence())
879                    .unwrap_or(false)
880                {
881                    let op1 = opstack.pop().expect("non-empty opstack");
882                    postfix.push(Entry::Operator(op1));
883                }
884
885                // Then push the op from this iteration onto the stack.
886                opstack.push(op);
887            }
888        }
889
890        // Pop all remaining symbols from opstack and append them to postfix.
891        postfix.extend(opstack.into_iter().rev().map(Entry::Operator));
892
893        if postfix.is_empty() {
894            let msg = "Cannot parse an empty expression";
895            Err(Error::new(input.span(), msg))?
896        }
897
898        // Flatten the postfix vector into a single MirScalarExpr.
899        let mut stack = vec![];
900        postfix.reverse();
901        while let Some(entry) = postfix.pop() {
902            match entry {
903                Entry::Operand(expr) => {
904                    stack.push(expr);
905                }
906                Entry::Operator(Op::Unr(func)) => {
907                    let expr = Box::new(stack.pop().expect("non-empty stack"));
908                    stack.push(MirScalarExpr::CallUnary { func, expr });
909                }
910                Entry::Operator(Op::Neg(func)) => {
911                    let expr = Box::new(stack.pop().expect("non-empty stack"));
912                    stack.push(MirScalarExpr::CallUnary { func, expr }.not());
913                }
914                Entry::Operator(Op::Bin(func)) => {
915                    let expr2 = Box::new(stack.pop().expect("non-empty stack"));
916                    let expr1 = Box::new(stack.pop().expect("non-empty stack"));
917                    stack.push(MirScalarExpr::CallBinary { func, expr1, expr2 });
918                }
919                Entry::Operator(Op::Var(func)) => {
920                    let expr2 = stack.pop().expect("non-empty stack");
921                    let expr1 = stack.pop().expect("non-empty stack");
922                    let mut exprs = vec![];
923                    for expr in [expr1, expr2] {
924                        match expr {
925                            MirScalarExpr::CallVariadic { func: f, exprs: es } if f == func => {
926                                exprs.extend(es.into_iter());
927                            }
928                            expr => {
929                                exprs.push(expr);
930                            }
931                        }
932                    }
933                    stack.push(MirScalarExpr::CallVariadic { func, exprs });
934                }
935            }
936        }
937
938        if stack.len() != 1 {
939            let msg = "Cannot fold postfix vector into a single MirScalarExpr";
940            Err(Error::new(input.span(), msg))?
941        }
942
943        Ok(stack.pop().unwrap())
944    }
945
946    pub fn parse_operand(input: ParseStream) -> Result {
947        let lookahead = input.lookahead1();
948        if lookahead.peek(syn::Token![#]) {
949            parse_column(input)
950        } else if lookahead.peek(syn::Lit) || lookahead.peek(kw::null) {
951            parse_literal_ok(input)
952        } else if lookahead.peek(kw::error) {
953            parse_literal_err(input)
954        } else if lookahead.peek(kw::array) {
955            parse_array(input)
956        } else if lookahead.peek(kw::list) {
957            parse_list(input)
958        } else if lookahead.peek(syn::Ident) {
959            parse_apply(input)
960        } else if lookahead.peek(syn::token::Brace) {
961            let inner;
962            syn::braced!(inner in input);
963            parse_literal_array(&inner)
964        } else if lookahead.peek(syn::token::Bracket) {
965            let inner;
966            syn::bracketed!(inner in input);
967            parse_literal_list(&inner)
968        } else if lookahead.peek(syn::token::Paren) {
969            let inner;
970            syn::parenthesized!(inner in input);
971            parse_expr(&inner)
972        } else {
973            Err(lookahead.error()) // FIXME: support IfThenElse variants
974        }
975    }
976
977    pub fn parse_column(input: ParseStream) -> Result {
978        Ok(MirScalarExpr::column(parse_column_index(input)?))
979    }
980
981    pub fn parse_column_index(input: ParseStream) -> syn::Result<usize> {
982        input.parse::<syn::Token![#]>()?;
983        input.parse::<syn::LitInt>()?.base10_parse::<usize>()
984    }
985
986    pub fn parse_column_order(input: ParseStream) -> syn::Result<ColumnOrder> {
987        input.parse::<syn::Token![#]>()?;
988        let column = input.parse::<syn::LitInt>()?.base10_parse::<usize>()?;
989        let desc = input.eat(kw::desc) || !input.eat(kw::asc);
990        let nulls_last = input.eat(kw::nulls_last) || !input.eat(kw::nulls_first);
991        Ok(ColumnOrder {
992            column,
993            desc,
994            nulls_last,
995        })
996    }
997
998    fn parse_literal_ok(input: ParseStream) -> Result {
999        let mut row = Row::default();
1000        let mut packer = row.packer();
1001
1002        let typ = if input.eat(kw::null) {
1003            packer.push(Datum::Null);
1004            input.parse::<syn::Token![::]>()?;
1005            analyses::parse_scalar_type(input)?.nullable(true)
1006        } else {
1007            match input.parse::<syn::Lit>()? {
1008                syn::Lit::Str(l) => {
1009                    packer.push(Datum::from(l.value().as_str()));
1010                    Ok(String::as_column_type())
1011                }
1012                syn::Lit::Int(l) => {
1013                    packer.push(Datum::from(l.base10_parse::<i64>()?));
1014                    Ok(i64::as_column_type())
1015                }
1016                syn::Lit::Float(l) => {
1017                    packer.push(Datum::from(l.base10_parse::<f64>()?));
1018                    Ok(f64::as_column_type())
1019                }
1020                syn::Lit::Bool(l) => {
1021                    packer.push(Datum::from(l.value));
1022                    Ok(bool::as_column_type())
1023                }
1024                _ => Err(Error::new(input.span(), "cannot parse literal")),
1025            }?
1026        };
1027
1028        Ok(MirScalarExpr::Literal(Ok(row), typ))
1029    }
1030
1031    fn parse_literal_err(input: ParseStream) -> Result {
1032        input.parse::<kw::error>()?;
1033        let mut msg = {
1034            let content;
1035            syn::parenthesized!(content in input);
1036            content.parse::<syn::LitStr>()?.value()
1037        };
1038        let err = if msg.starts_with("internal error: ") {
1039            Ok(mz_expr::EvalError::Internal(msg.split_off(16).into()))
1040        } else {
1041            Err(Error::new(msg.span(), "expected `internal error: $msg`"))
1042        }?;
1043        Ok(MirScalarExpr::Literal(Err(err), bool::as_column_type())) // FIXME
1044    }
1045
1046    fn parse_literal_array(input: ParseStream) -> Result {
1047        use mz_expr::func::VariadicFunc::*;
1048
1049        let elem_type = ScalarType::Int64; // FIXME
1050        let func = ArrayCreate { elem_type };
1051        let exprs = input.parse_comma_sep(parse_literal_ok)?;
1052
1053        // Evaluate into a datum
1054        let temp_storage = RowArena::default();
1055        let datum = func.eval(&[], &temp_storage, &exprs).expect("datum");
1056        let typ = ScalarType::Array(Box::new(ScalarType::Int64)); // FIXME
1057        Ok(MirScalarExpr::literal_ok(datum, typ))
1058    }
1059
1060    fn parse_literal_list(input: ParseStream) -> Result {
1061        use mz_expr::func::VariadicFunc::*;
1062
1063        let elem_type = ScalarType::Int64; // FIXME
1064        let func = ListCreate { elem_type };
1065        let exprs = input.parse_comma_sep(parse_literal_ok)?;
1066
1067        // Evaluate into a datum
1068        let temp_storage = RowArena::default();
1069        let datum = func.eval(&[], &temp_storage, &exprs).expect("datum");
1070        let typ = ScalarType::Array(Box::new(ScalarType::Int64)); // FIXME
1071        Ok(MirScalarExpr::literal_ok(datum, typ))
1072    }
1073
1074    fn parse_array(input: ParseStream) -> Result {
1075        use mz_expr::func::VariadicFunc::*;
1076
1077        input.parse::<kw::array>()?;
1078
1079        // parse brackets
1080        let inner;
1081        syn::bracketed!(inner in input);
1082
1083        let elem_type = ScalarType::Int64; // FIXME
1084        let func = ArrayCreate { elem_type };
1085        let exprs = inner.parse_comma_sep(parse_expr)?;
1086
1087        Ok(MirScalarExpr::CallVariadic { func, exprs })
1088    }
1089
1090    fn parse_list(input: ParseStream) -> Result {
1091        use mz_expr::func::VariadicFunc::*;
1092
1093        input.parse::<kw::list>()?;
1094
1095        // parse brackets
1096        let inner;
1097        syn::bracketed!(inner in input);
1098
1099        let elem_type = ScalarType::Int64; // FIXME
1100        let func = ListCreate { elem_type };
1101        let exprs = inner.parse_comma_sep(parse_expr)?;
1102
1103        Ok(MirScalarExpr::CallVariadic { func, exprs })
1104    }
1105
1106    fn parse_apply(input: ParseStream) -> Result {
1107        let ident = input.parse::<syn::Ident>()?;
1108
1109        // parse parentheses
1110        let inner;
1111        syn::parenthesized!(inner in input);
1112
1113        let parse_nullary = |func: UnmaterializableFunc| -> Result {
1114            Ok(MirScalarExpr::CallUnmaterializable(func))
1115        };
1116        let parse_unary = |func: UnaryFunc| -> Result {
1117            let expr = Box::new(parse_expr(&inner)?);
1118            Ok(MirScalarExpr::CallUnary { func, expr })
1119        };
1120        let parse_binary = |func: BinaryFunc| -> Result {
1121            let expr1 = Box::new(parse_expr(&inner)?);
1122            inner.parse::<syn::Token![,]>()?;
1123            let expr2 = Box::new(parse_expr(&inner)?);
1124            Ok(MirScalarExpr::CallBinary { func, expr1, expr2 })
1125        };
1126        let parse_variadic = |func: VariadicFunc| -> Result {
1127            let exprs = inner.parse_comma_sep(parse_expr)?;
1128            Ok(MirScalarExpr::CallVariadic { func, exprs })
1129        };
1130
1131        // Infix binary and variadic function calls are handled in `parse_scalar_expr`.
1132        //
1133        // Some restrictions apply with the current state of the code,
1134        // most notably one cannot handle overloaded function names because we don't want to do
1135        // name resolution in the parser.
1136        match ident.to_string().to_lowercase().as_str() {
1137            // Supported unmaterializable (a.k.a. nullary) functions:
1138            "mz_environment_id" => parse_nullary(UnmaterializableFunc::MzEnvironmentId),
1139            // Supported unary functions:
1140            "abs" => parse_unary(mz_expr::func::AbsInt64.into()),
1141            "not" => parse_unary(mz_expr::func::Not.into()),
1142            // Supported binary functions:
1143            "ltrim" => parse_binary(BinaryFunc::TrimLeading),
1144            // Supported variadic functions:
1145            "greatest" => parse_variadic(VariadicFunc::Greatest),
1146            _ => Err(Error::new(ident.span(), "unsupported function name")),
1147        }
1148    }
1149
1150    pub fn parse_join_equivalences(input: ParseStream) -> syn::Result<Vec<Vec<MirScalarExpr>>> {
1151        let mut equivalences = vec![];
1152        while !input.is_empty() {
1153            let mut equivalence = vec![];
1154            loop {
1155                let mut worklist = vec![parse_operand(input)?];
1156                while let Some(operand) = worklist.pop() {
1157                    // Be more lenient and support parenthesized equivalences,
1158                    // e.g. `... AND (x = u + v = z + 1) AND ...`.
1159                    if let MirScalarExpr::CallBinary {
1160                        func: BinaryFunc::Eq,
1161                        expr1,
1162                        expr2,
1163                    } = operand
1164                    {
1165                        // We reverse the order in the worklist in order to get
1166                        // the correct order in the equivalence class.
1167                        worklist.push(*expr2);
1168                        worklist.push(*expr1);
1169                    } else {
1170                        equivalence.push(operand);
1171                    }
1172                }
1173                if !input.eat(syn::Token![=]) {
1174                    break;
1175                }
1176            }
1177            equivalences.push(equivalence);
1178            input.eat(kw::AND);
1179        }
1180        Ok(equivalences)
1181    }
1182}
1183
1184/// Support for parsing [mz_expr::AggregateExpr].
1185mod aggregate {
1186    use mz_expr::{AggregateExpr, MirScalarExpr};
1187
1188    use super::*;
1189
1190    type Result = syn::Result<AggregateExpr>;
1191
1192    pub fn parse_expr(input: ParseStream) -> Result {
1193        use mz_expr::AggregateFunc::*;
1194
1195        // Some restrictions apply with the current state of the code,
1196        // most notably one cannot handle overloaded function names because we don't want to do
1197        // name resolution in the parser.
1198        let ident = input.parse::<syn::Ident>()?;
1199        let func = match ident.to_string().to_lowercase().as_str() {
1200            "count" => Count,
1201            "any" => Any,
1202            "all" => All,
1203            "max" => MaxInt64,
1204            "min" => MinInt64,
1205            "sum" => SumInt64,
1206            _ => Err(Error::new(ident.span(), "unsupported function name"))?,
1207        };
1208
1209        // parse parentheses
1210        let inner;
1211        syn::parenthesized!(inner in input);
1212
1213        if func == Count && inner.eat(syn::Token![*]) {
1214            Ok(AggregateExpr {
1215                func,
1216                expr: MirScalarExpr::literal_true(),
1217                distinct: false, // TODO: fix explain output
1218            })
1219        } else {
1220            let distinct = inner.eat(kw::distinct);
1221            let expr = scalar::parse_expr(&inner)?;
1222            Ok(AggregateExpr {
1223                func,
1224                expr,
1225                distinct,
1226            })
1227        }
1228    }
1229}
1230
1231/// Support for parsing [mz_repr::Row].
1232mod row {
1233    use mz_repr::{Datum, Row, RowPacker};
1234
1235    use super::*;
1236
1237    impl Parse for Parsed<Row> {
1238        fn parse(input: ParseStream) -> syn::Result<Self> {
1239            let mut row = Row::default();
1240            let mut packer = ParseRow::new(&mut row);
1241
1242            loop {
1243                if input.is_empty() {
1244                    break;
1245                }
1246                packer.parse_datum(input)?;
1247                if input.is_empty() {
1248                    break;
1249                }
1250                input.parse::<syn::Token![,]>()?;
1251            }
1252
1253            Ok(Parsed(row))
1254        }
1255    }
1256
1257    impl From<Parsed<Row>> for Row {
1258        fn from(parsed: Parsed<Row>) -> Self {
1259            parsed.0
1260        }
1261    }
1262
1263    struct ParseRow<'a>(RowPacker<'a>);
1264
1265    impl<'a> ParseRow<'a> {
1266        fn new(row: &'a mut Row) -> Self {
1267            Self(row.packer())
1268        }
1269
1270        fn parse_datum(&mut self, input: ParseStream) -> syn::Result<()> {
1271            if input.eat(kw::null) {
1272                self.0.push(Datum::Null)
1273            } else {
1274                match input.parse::<syn::Lit>()? {
1275                    syn::Lit::Str(l) => self.0.push(Datum::from(l.value().as_str())),
1276                    syn::Lit::Int(l) => self.0.push(Datum::from(l.base10_parse::<i64>()?)),
1277                    syn::Lit::Float(l) => self.0.push(Datum::from(l.base10_parse::<f64>()?)),
1278                    syn::Lit::Bool(l) => self.0.push(Datum::from(l.value)),
1279                    _ => Err(Error::new(input.span(), "cannot parse literal"))?,
1280                }
1281            }
1282            Ok(())
1283        }
1284    }
1285}
1286
1287mod analyses {
1288    use mz_repr::{ColumnType, ScalarType};
1289
1290    use super::*;
1291
1292    #[derive(Default)]
1293    pub struct Analyses {
1294        pub types: Option<Vec<ColumnType>>,
1295        pub keys: Option<Vec<Vec<usize>>>,
1296    }
1297
1298    pub fn parse_analyses(input: ParseStream) -> syn::Result<Analyses> {
1299        let mut analyses = Analyses::default();
1300
1301        // Analyses are optional, appearing after a `//` at the end of the
1302        // line. However, since the syn lexer eats comments, we assume that `//`
1303        // was replaced with `::` upfront.
1304        if input.eat(syn::Token![::]) {
1305            let inner;
1306            syn::braced!(inner in input);
1307
1308            let (start, end) = (inner.span().start(), inner.span().end());
1309            if start.line != end.line {
1310                let msg = "analyses should not span more than one line".to_string();
1311                Err(Error::new(inner.span(), msg))?
1312            }
1313
1314            while inner.peek(syn::Ident) {
1315                let ident = inner.parse::<syn::Ident>()?.to_string();
1316                match ident.as_str() {
1317                    "types" => {
1318                        inner.parse::<syn::Token![:]>()?;
1319                        let value = inner.parse::<syn::LitStr>()?.value();
1320                        analyses.types = Some(parse_types.parse_str(&value)?);
1321                    }
1322                    // TODO: support keys
1323                    key => {
1324                        let msg = format!("unexpected analysis type `{}`", key);
1325                        Err(Error::new(inner.span(), msg))?;
1326                    }
1327                }
1328            }
1329        }
1330        Ok(analyses)
1331    }
1332
1333    fn parse_types(input: ParseStream) -> syn::Result<Vec<ColumnType>> {
1334        let inner;
1335        syn::parenthesized!(inner in input);
1336        inner.parse_comma_sep(parse_column_type)
1337    }
1338
1339    pub fn parse_column_type(input: ParseStream) -> syn::Result<ColumnType> {
1340        let scalar_type = parse_scalar_type(input)?;
1341        Ok(scalar_type.nullable(input.eat(syn::Token![?])))
1342    }
1343
1344    pub fn parse_scalar_type(input: ParseStream) -> syn::Result<ScalarType> {
1345        let lookahead = input.lookahead1();
1346
1347        let scalar_type = if input.look_and_eat(bigint, &lookahead) {
1348            ScalarType::Int64
1349        } else if input.look_and_eat(double, &lookahead) {
1350            input.parse::<precision>()?;
1351            ScalarType::Float64
1352        } else if input.look_and_eat(boolean, &lookahead) {
1353            ScalarType::Bool
1354        } else if input.look_and_eat(character, &lookahead) {
1355            input.parse::<varying>()?;
1356            ScalarType::VarChar { max_length: None }
1357        } else if input.look_and_eat(integer, &lookahead) {
1358            ScalarType::Int32
1359        } else if input.look_and_eat(smallint, &lookahead) {
1360            ScalarType::Int16
1361        } else if input.look_and_eat(text, &lookahead) {
1362            ScalarType::String
1363        } else {
1364            Err(lookahead.error())?
1365        };
1366
1367        Ok(scalar_type)
1368    }
1369
1370    syn::custom_keyword!(bigint);
1371    syn::custom_keyword!(boolean);
1372    syn::custom_keyword!(character);
1373    syn::custom_keyword!(double);
1374    syn::custom_keyword!(integer);
1375    syn::custom_keyword!(precision);
1376    syn::custom_keyword!(smallint);
1377    syn::custom_keyword!(text);
1378    syn::custom_keyword!(varying);
1379}
1380
1381pub enum Def {
1382    Source {
1383        name: String,
1384        cols: Vec<String>,
1385        typ: mz_repr::RelationType,
1386    },
1387}
1388
1389mod def {
1390    use mz_repr::{ColumnType, RelationType};
1391
1392    use super::*;
1393
1394    pub fn parse_def(ctx: CtxRef, input: ParseStream) -> syn::Result<Def> {
1395        parse_def_source(ctx, input) // only one variant for now
1396    }
1397
1398    fn parse_def_source(ctx: CtxRef, input: ParseStream) -> syn::Result<Def> {
1399        let reduce = input.parse::<def::DefSource>()?;
1400
1401        let name = {
1402            input.parse::<def::name>()?;
1403            input.parse::<syn::Token![=]>()?;
1404            input.parse::<syn::Ident>()?.to_string()
1405        };
1406
1407        let keys = if input.eat(kw::keys) {
1408            input.parse::<syn::Token![=]>()?;
1409            let inner;
1410            syn::bracketed!(inner in input);
1411            inner.parse_comma_sep(|input| {
1412                let inner;
1413                syn::bracketed!(inner in input);
1414                inner.parse_comma_sep(scalar::parse_column_index)
1415            })?
1416        } else {
1417            vec![]
1418        };
1419
1420        let parse_inputs = ParseChildren::new(input, reduce.span().start());
1421        let (cols, column_types) = {
1422            let source_columns = parse_inputs.parse_many(ctx, parse_def_source_column)?;
1423            let mut column_names = vec![];
1424            let mut column_types = vec![];
1425            for (column_name, column_type) in source_columns {
1426                column_names.push(column_name);
1427                column_types.push(column_type);
1428            }
1429            (column_names, column_types)
1430        };
1431
1432        let typ = RelationType { column_types, keys };
1433
1434        Ok(Def::Source { name, cols, typ })
1435    }
1436
1437    fn parse_def_source_column(
1438        _ctx: CtxRef,
1439        input: ParseStream,
1440    ) -> syn::Result<(String, ColumnType)> {
1441        input.parse::<syn::Token![-]>()?;
1442        let column_name = input.parse::<syn::Ident>()?.to_string();
1443        input.parse::<syn::Token![:]>()?;
1444        let column_type = analyses::parse_column_type(input)?;
1445        Ok((column_name, column_type))
1446    }
1447
1448    syn::custom_keyword!(DefSource);
1449    syn::custom_keyword!(name);
1450}
1451
1452/// Help utilities used by sibling modules.
1453mod util {
1454    use syn::parse::{Lookahead1, ParseBuffer, Peek};
1455
1456    use super::*;
1457
1458    /// Extension methods for [`syn::parse::ParseBuffer`].
1459    pub trait ParseBufferExt<'a> {
1460        fn look_and_eat<T: Eat>(&self, token: T, lookahead: &Lookahead1<'a>) -> bool;
1461
1462        /// Consumes a token `T` if present.
1463        fn eat<T: Eat>(&self, t: T) -> bool;
1464
1465        /// Consumes two tokens `T1 T2` if present in that order.
1466        fn eat2<T1: Eat, T2: Eat>(&self, t1: T1, t2: T2) -> bool;
1467
1468        /// Consumes three tokens `T1 T2 T3` if present in that order.
1469        fn eat3<T1: Eat, T2: Eat, T3: Eat>(&self, t1: T1, t2: T2, t3: T3) -> bool;
1470
1471        // Parse a comma-separated list of items into a vector.
1472        fn parse_comma_sep<T>(&self, p: fn(ParseStream) -> syn::Result<T>) -> syn::Result<Vec<T>>;
1473    }
1474
1475    impl<'a> ParseBufferExt<'a> for ParseBuffer<'a> {
1476        /// Consumes a token `T` if present, looking it up using the provided
1477        /// [`Lookahead1`] instance.
1478        fn look_and_eat<T: Eat>(&self, token: T, lookahead: &Lookahead1<'a>) -> bool {
1479            if lookahead.peek(token) {
1480                self.parse::<T::Token>().unwrap();
1481                true
1482            } else {
1483                false
1484            }
1485        }
1486
1487        fn eat<T: Eat>(&self, t: T) -> bool {
1488            if self.peek(t) {
1489                self.parse::<T::Token>().unwrap();
1490                true
1491            } else {
1492                false
1493            }
1494        }
1495
1496        fn eat2<T1: Eat, T2: Eat>(&self, t1: T1, t2: T2) -> bool {
1497            if self.peek(t1) && self.peek2(t2) {
1498                self.parse::<T1::Token>().unwrap();
1499                self.parse::<T2::Token>().unwrap();
1500                true
1501            } else {
1502                false
1503            }
1504        }
1505
1506        fn eat3<T1: Eat, T2: Eat, T3: Eat>(&self, t1: T1, t2: T2, t3: T3) -> bool {
1507            if self.peek(t1) && self.peek2(t2) && self.peek3(t3) {
1508                self.parse::<T1::Token>().unwrap();
1509                self.parse::<T2::Token>().unwrap();
1510                self.parse::<T3::Token>().unwrap();
1511                true
1512            } else {
1513                false
1514            }
1515        }
1516
1517        fn parse_comma_sep<T>(&self, p: fn(ParseStream) -> syn::Result<T>) -> syn::Result<Vec<T>> {
1518            Ok(self
1519                .parse_terminated(p, syn::Token![,])?
1520                .into_iter()
1521                .collect::<Vec<_>>())
1522        }
1523    }
1524
1525    // Helper trait for types that can be eaten.
1526    //
1527    // Implementing types must also implement [`Peek`], and the associated
1528    // [`Peek::Token`] type should implement [`Parse`]). For some reason the
1529    // latter bound is not present in [`Peek`] even if it makes a lot of sense,
1530    // which is why we need this helper.
1531    pub trait Eat: Peek<Token = Self::_Token> {
1532        type _Token: Parse;
1533    }
1534
1535    impl<T> Eat for T
1536    where
1537        T: Peek,
1538        T::Token: Parse,
1539    {
1540        type _Token = T::Token;
1541    }
1542
1543    pub struct Ctx<'a> {
1544        pub catalog: &'a TestCatalog,
1545    }
1546
1547    pub type CtxRef<'a> = &'a Ctx<'a>;
1548
1549    /// Newtype for external types that need to implement [Parse].
1550    pub struct Parsed<T>(pub T);
1551
1552    /// Provides facilities for parsing
1553    pub struct ParseChildren<'a> {
1554        stream: ParseStream<'a>,
1555        parent: LineColumn,
1556    }
1557
1558    impl<'a> ParseChildren<'a> {
1559        pub fn new(stream: ParseStream<'a>, parent: LineColumn) -> Self {
1560            Self { stream, parent }
1561        }
1562
1563        pub fn parse_one<C, T>(
1564            &self,
1565            ctx: C,
1566            function: fn(C, ParseStream) -> syn::Result<T>,
1567        ) -> syn::Result<T> {
1568            match self.maybe_child() {
1569                Ok(_) => function(ctx, self.stream),
1570                Err(e) => Err(e),
1571            }
1572        }
1573
1574        pub fn parse_many<C: Copy, T>(
1575            &self,
1576            ctx: C,
1577            function: fn(C, ParseStream) -> syn::Result<T>,
1578        ) -> syn::Result<Vec<T>> {
1579            let mut inputs = vec![self.parse_one(ctx, function)?];
1580            while self.maybe_child().is_ok() {
1581                inputs.push(function(ctx, self.stream)?);
1582            }
1583            Ok(inputs)
1584        }
1585
1586        fn maybe_child(&self) -> syn::Result<()> {
1587            let start = self.stream.span().start();
1588            if start.line <= self.parent.line {
1589                let msg = format!("child expected at line > {}", self.parent.line);
1590                Err(Error::new(self.stream.span(), msg))?
1591            }
1592            if start.column != self.parent.column + 2 {
1593                let msg = format!("child expected at column {}", self.parent.column + 2);
1594                Err(Error::new(self.stream.span(), msg))?
1595            }
1596            Ok(())
1597        }
1598    }
1599}
1600
/// Custom keywords used while parsing.
///
/// Each `syn::custom_keyword!` invocation declares one keyword that the
/// parsing code in this file refers to as `kw::<name>`, for example in
/// `input.eat(kw::AND)` or `input.parse::<kw::error>()`.
///
/// Differently-cased spellings such as `null` / `NULL`, `distinct` /
/// `Distinct` and `project` / `Project` are declared as separate keywords.
mod kw {
    // Entries are listed in case-insensitive alphabetical order.
    syn::custom_keyword!(aggregates);
    syn::custom_keyword!(AND);
    syn::custom_keyword!(ArrangeBy);
    syn::custom_keyword!(array);
    syn::custom_keyword!(asc);
    syn::custom_keyword!(Constant);
    syn::custom_keyword!(CrossJoin);
    syn::custom_keyword!(cte);
    syn::custom_keyword!(desc);
    syn::custom_keyword!(distinct);
    syn::custom_keyword!(Distinct);
    syn::custom_keyword!(empty);
    syn::custom_keyword!(eq);
    syn::custom_keyword!(error);
    syn::custom_keyword!(exp_group_size);
    syn::custom_keyword!(FALSE);
    syn::custom_keyword!(Filter);
    syn::custom_keyword!(FlatMap);
    syn::custom_keyword!(Get);
    syn::custom_keyword!(group_by);
    syn::custom_keyword!(IS);
    syn::custom_keyword!(Join);
    syn::custom_keyword!(keys);
    syn::custom_keyword!(limit);
    syn::custom_keyword!(list);
    syn::custom_keyword!(Map);
    syn::custom_keyword!(monotonic);
    syn::custom_keyword!(Mutually);
    syn::custom_keyword!(Negate);
    syn::custom_keyword!(NOT);
    syn::custom_keyword!(null);
    syn::custom_keyword!(NULL);
    syn::custom_keyword!(nulls_first);
    syn::custom_keyword!(nulls_last);
    syn::custom_keyword!(offset);
    syn::custom_keyword!(on);
    syn::custom_keyword!(OR);
    syn::custom_keyword!(order_by);
    syn::custom_keyword!(project);
    syn::custom_keyword!(Project);
    syn::custom_keyword!(Recursive);
    syn::custom_keyword!(Reduce);
    syn::custom_keyword!(Return);
    syn::custom_keyword!(Threshold);
    syn::custom_keyword!(TopK);
    syn::custom_keyword!(TRUE);
    syn::custom_keyword!(Union);
    syn::custom_keyword!(With);
    // NOTE(review): no use of `x` is visible in this part of the file —
    // confirm that it is still needed.
    syn::custom_keyword!(x);
}