Skip to main content

mz_expr_parser/
parser.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10// Copyright Materialize, Inc. and contributors. All rights reserved.
11//
12// Use of this software is governed by the Business Source License
13// included in the LICENSE file.
14//
15// As of the Change Date specified in that file, in accordance with
16// the Business Source License, use of this software will be governed
17// by the Apache License, Version 2.0.
18
19use mz_ore::collections::CollectionExt;
20use proc_macro2::LineColumn;
21use syn::Error;
22use syn::parse::{Parse, ParseStream, Parser};
23use syn::spanned::Spanned;
24
25use super::TestCatalog;
26
27use self::util::*;
28
29/// Builds a [mz_expr::MirRelationExpr] from a string.
30pub fn try_parse_mir(catalog: &TestCatalog, s: &str) -> Result<mz_expr::MirRelationExpr, String> {
31    // Define a Parser that constructs a (read-only) parsing context `ctx` and
32    // delegates to `relation::parse_expr` by passing a `ctx` as a shared ref.
33    let parser = move |input: ParseStream| {
34        let ctx = Ctx { catalog };
35        relation::parse_expr(&ctx, input)
36    };
37    // Since the syn lexer doesn't parse comments, we replace all `// {`
38    // occurrences in the input string with `:: {`.
39    let s = s.replace("// {", ":: {");
40    // Call the parser with the given input string.
41    let mut expr = parser.parse_str(&s).map_err(|err| {
42        let (line, column) = (err.span().start().line, err.span().start().column);
43        format!("parse error at {line}:{column}:\n{err}\n")
44    })?;
45    // Fix the types of the local let bindings of the parsed expression in a
46    // post-processing pass.
47    relation::fix_types(&mut expr, &mut relation::FixTypesCtx::default())?;
48    // Return the parsed, post-processed expression.
49    Ok(expr)
50}
51
52/// Builds a source definition from a string.
53pub fn try_parse_def(catalog: &TestCatalog, s: &str) -> Result<Def, String> {
54    // Define a Parser that constructs a (read-only) parsing context `ctx` and
55    // delegates to `relation::parse_expr` by passing a `ctx` as a shared ref.
56    let parser = move |input: ParseStream| {
57        let ctx = Ctx { catalog };
58        def::parse_def(&ctx, input)
59    };
60    // Call the parser with the given input string.
61    let def = parser.parse_str(s).map_err(|err| {
62        let (line, column) = (err.span().start().line, err.span().start().column);
63        format!("parse error at {line}:{column}:\n{err}\n")
64    })?;
65    // Return the parsed, post-processed expression.
66    Ok(def)
67}
68
69/// Support for parsing [mz_expr::MirRelationExpr].
70mod relation {
71    use std::collections::BTreeMap;
72
73    use mz_expr::{AccessStrategy, Id, JoinImplementation, LocalId, MirRelationExpr};
74    use mz_repr::{Diff, ReprRelationType, Row, SqlScalarType};
75
76    use crate::parser::analyses::Analyses;
77
78    use super::*;
79
80    type Result = syn::Result<MirRelationExpr>;
81
82    pub fn parse_expr(ctx: CtxRef, input: ParseStream) -> Result {
83        let lookahead = input.lookahead1();
84        if lookahead.peek(kw::Constant) {
85            parse_constant(ctx, input)
86        } else if lookahead.peek(kw::Get) {
87            parse_get(ctx, input)
88        } else if lookahead.peek(kw::Return) {
89            parse_let_or_letrec_old(ctx, input)
90        } else if lookahead.peek(kw::With) {
91            parse_let_or_letrec(ctx, input)
92        } else if lookahead.peek(kw::Project) {
93            parse_project(ctx, input)
94        } else if lookahead.peek(kw::Map) {
95            parse_map(ctx, input)
96        } else if lookahead.peek(kw::FlatMap) {
97            parse_flat_map(ctx, input)
98        } else if lookahead.peek(kw::Filter) {
99            parse_filter(ctx, input)
100        } else if lookahead.peek(kw::CrossJoin) {
101            parse_cross_join(ctx, input)
102        } else if lookahead.peek(kw::Join) {
103            parse_join(ctx, input)
104        } else if lookahead.peek(kw::Distinct) {
105            parse_distinct(ctx, input)
106        } else if lookahead.peek(kw::Reduce) {
107            parse_reduce(ctx, input)
108        } else if lookahead.peek(kw::TopK) {
109            parse_top_k(ctx, input)
110        } else if lookahead.peek(kw::Negate) {
111            parse_negate(ctx, input)
112        } else if lookahead.peek(kw::Threshold) {
113            parse_threshold(ctx, input)
114        } else if lookahead.peek(kw::Union) {
115            parse_union(ctx, input)
116        } else if lookahead.peek(kw::ArrangeBy) {
117            parse_arrange_by(ctx, input)
118        } else {
119            Err(lookahead.error())
120        }
121    }
122
123    fn parse_constant(ctx: CtxRef, input: ParseStream) -> Result {
124        let constant = input.parse::<kw::Constant>()?;
125
126        let parse_typ = |input: ParseStream| -> syn::Result<ReprRelationType> {
127            let analyses = analyses::parse_analyses(input)?;
128            let Some(column_types) = analyses.types else {
129                let msg = "Missing expected `types` analyses for Constant line";
130                Err(Error::new(input.span(), msg))?
131            };
132            let keys = analyses.keys.unwrap_or_default();
133            Ok(ReprRelationType::new(column_types).with_keys(keys))
134        };
135        if input.eat3(syn::Token![<], kw::empty, syn::Token![>]) {
136            let typ = parse_typ(input)?;
137            Ok(MirRelationExpr::Constant {
138                rows: Ok(vec![]),
139                typ,
140            })
141        } else {
142            let typ = parse_typ(input)?;
143            let parse_children = ParseChildren::new(input, constant.span().start());
144            let rows = Ok(parse_children.parse_many(ctx, parse_constant_entry)?);
145            Ok(MirRelationExpr::Constant { rows, typ })
146        }
147    }
148
149    fn parse_constant_entry(_ctx: CtxRef, input: ParseStream) -> syn::Result<(Row, Diff)> {
150        input.parse::<syn::Token![-]>()?;
151
152        let (row, diff);
153
154        let inner1;
155        syn::parenthesized!(inner1 in input);
156
157        if inner1.peek(syn::token::Paren) {
158            let inner2;
159            syn::parenthesized!(inner2 in inner1);
160            row = inner2.parse::<Parsed<Row>>()?.into();
161            inner1.parse::<kw::x>()?;
162            diff = match inner1.parse::<syn::Lit>()? {
163                syn::Lit::Int(l) => Ok(l.base10_parse::<Diff>()?),
164                _ => Err(Error::new(inner1.span(), "expected Diff literal")),
165            }?;
166        } else {
167            row = inner1.parse::<Parsed<Row>>()?.into();
168            diff = Diff::ONE;
169        }
170
171        Ok((row, diff))
172    }
173
174    fn parse_get(ctx: CtxRef, input: ParseStream) -> Result {
175        input.parse::<kw::Get>()?;
176
177        let ident = input.parse::<syn::Ident>()?;
178        match ctx.catalog.get(&ident.to_string()) {
179            Some((id, _cols, typ)) => Ok(MirRelationExpr::Get {
180                id: Id::Global(*id),
181                typ: ReprRelationType::from(typ),
182                access_strategy: AccessStrategy::UnknownOrLocal,
183            }),
184            None => Ok(MirRelationExpr::Get {
185                id: Id::Local(parse_local_id(ident)?),
186                typ: ReprRelationType::empty(),
187                access_strategy: AccessStrategy::UnknownOrLocal,
188            }),
189        }
190    }
191
192    /// Parses a Let or a LetRec with the old order: Return first, and then CTEs in descending order.
193    fn parse_let_or_letrec_old(ctx: CtxRef, input: ParseStream) -> Result {
194        let return_ = input.parse::<kw::Return>()?;
195        let parse_body = ParseChildren::new(input, return_.span().start());
196        let body = parse_body.parse_one(ctx, parse_expr)?;
197
198        let with = input.parse::<kw::With>()?;
199        let recursive = input.eat2(kw::Mutually, kw::Recursive);
200        let parse_ctes = ParseChildren::new(input, with.span().start());
201        let mut ctes = parse_ctes.parse_many(ctx, parse_cte)?;
202
203        if ctes.is_empty() {
204            let msg = "At least one Let/LetRec cte binding expected";
205            Err(Error::new(input.span(), msg))?
206        }
207
208        ctes.reverse();
209        let cte_ids = ctes.iter().map(|(id, _, _)| id);
210        if !cte_ids.clone().is_sorted() {
211            let msg = format!(
212                "Error parsing Let/LetRec: seen Return before With, but cte ids are not ordered descending: {:?}",
213                cte_ids.collect::<Vec<_>>()
214            );
215            Err(Error::new(input.span(), msg))?
216        }
217        build_let_or_let_rec(ctes, body, recursive, with)
218    }
219
220    /// Parses a Let or a LetRec with the new order: CTEs first in ascending order, and then Return.
221    fn parse_let_or_letrec(ctx: CtxRef, input: ParseStream) -> Result {
222        let with = input.parse::<kw::With>()?;
223        let recursive = input.eat2(kw::Mutually, kw::Recursive);
224        let parse_ctes = ParseChildren::new(input, with.span().start());
225        let ctes = parse_ctes.parse_many(ctx, parse_cte)?;
226
227        let return_ = input.parse::<kw::Return>()?;
228        let parse_body = ParseChildren::new(input, return_.span().start());
229        let body = parse_body.parse_one(ctx, parse_expr)?;
230
231        if ctes.is_empty() {
232            let msg = "At least one `let cte` binding expected";
233            Err(Error::new(input.span(), msg))?
234        }
235
236        let cte_ids = ctes.iter().map(|(id, _, _)| id);
237        if !cte_ids.clone().is_sorted() {
238            let msg = format!(
239                "Error parsing Let/LetRec: seen With before Return, but cte ids are not ordered ascending: {:?}",
240                cte_ids.collect::<Vec<_>>()
241            );
242            Err(Error::new(input.span(), msg))?
243        }
244        build_let_or_let_rec(ctes, body, recursive, with)
245    }
246
247    fn build_let_or_let_rec(
248        ctes: Vec<(LocalId, Analyses, MirRelationExpr)>,
249        body: MirRelationExpr,
250        recursive: bool,
251        with: kw::With,
252    ) -> Result {
253        if recursive {
254            let (mut ids, mut values, mut limits) = (vec![], vec![], vec![]);
255            for (id, analyses, value) in ctes.into_iter() {
256                let typ = {
257                    let Some(column_types) = analyses.types else {
258                        let msg = format!("`let {}` needs a `types` analyses", id);
259                        Err(Error::new(with.span(), msg))?
260                    };
261                    let keys = analyses.keys.unwrap_or_default();
262                    ReprRelationType::new(column_types).with_keys(keys)
263                };
264                // An ugly-ugly hack to pass the type information of the WMR CTE
265                // to the `fix_types` pass.
266                let value = {
267                    let get_cte = MirRelationExpr::Get {
268                        id: Id::Local(id),
269                        typ,
270                        access_strategy: AccessStrategy::UnknownOrLocal,
271                    };
272                    // Do not use the `union` smart constructor here!
273                    MirRelationExpr::Union {
274                        base: Box::new(get_cte),
275                        inputs: vec![value],
276                    }
277                };
278
279                ids.push(id);
280                values.push(value);
281                limits.push(None); // TODO: support limits
282            }
283
284            Ok(MirRelationExpr::LetRec {
285                ids,
286                values,
287                limits,
288                body: Box::new(body),
289            })
290        } else {
291            let mut body = body;
292            for (id, _, value) in ctes.into_iter().rev() {
293                body = MirRelationExpr::Let {
294                    id,
295                    value: Box::new(value),
296                    body: Box::new(body),
297                };
298            }
299            Ok(body)
300        }
301    }
302
303    fn parse_cte(
304        ctx: CtxRef,
305        input: ParseStream,
306    ) -> syn::Result<(LocalId, analyses::Analyses, MirRelationExpr)> {
307        let cte = input.parse::<kw::cte>()?;
308
309        let ident = input.parse::<syn::Ident>()?;
310        let id = parse_local_id(ident)?;
311
312        input.parse::<syn::Token![=]>()?;
313
314        let analyses = analyses::parse_analyses(input)?;
315
316        let parse_value = ParseChildren::new(input, cte.span().start());
317        let value = parse_value.parse_one(ctx, parse_expr)?;
318
319        Ok((id, analyses, value))
320    }
321
322    fn parse_project(ctx: CtxRef, input: ParseStream) -> Result {
323        let project = input.parse::<kw::Project>()?;
324
325        let content;
326        syn::parenthesized!(content in input);
327        let outputs = content.parse_comma_sep(scalar::parse_column_index)?;
328        let parse_input = ParseChildren::new(input, project.span().start());
329        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
330
331        Ok(MirRelationExpr::Project { input, outputs })
332    }
333
334    fn parse_map(ctx: CtxRef, input: ParseStream) -> Result {
335        let map = input.parse::<kw::Map>()?;
336
337        let scalars = {
338            let inner;
339            syn::parenthesized!(inner in input);
340            scalar::parse_exprs(&inner)?
341        };
342
343        let parse_input = ParseChildren::new(input, map.span().start());
344        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
345
346        Ok(MirRelationExpr::Map { input, scalars })
347    }
348
349    fn parse_flat_map(ctx: CtxRef, input: ParseStream) -> Result {
350        use mz_expr::TableFunc::*;
351
352        let flat_map = input.parse::<kw::FlatMap>()?;
353
354        let ident = input.parse::<syn::Ident>()?;
355        let func = match ident.to_string().to_lowercase().as_str() {
356            "unnest_list" => UnnestList {
357                el_typ: SqlScalarType::Int64, // FIXME
358            },
359            "unnest_array" => UnnestArray {
360                el_typ: SqlScalarType::Int64, // FIXME
361            },
362            "wrap1" => Wrap {
363                types: vec![
364                    SqlScalarType::Int64.nullable(true), // FIXME
365                ],
366                width: 1,
367            },
368            "wrap2" => Wrap {
369                types: vec![
370                    SqlScalarType::Int64.nullable(true), // FIXME
371                    SqlScalarType::Int64.nullable(true), // FIXME
372                ],
373                width: 2,
374            },
375            "wrap3" => Wrap {
376                types: vec![
377                    SqlScalarType::Int64.nullable(true), // FIXME
378                    SqlScalarType::Int64.nullable(true), // FIXME
379                    SqlScalarType::Int64.nullable(true), // FIXME
380                ],
381                width: 3,
382            },
383            "generate_series" => GenerateSeriesInt64,
384            "jsonb_object_keys" => JsonbObjectKeys,
385            _ => Err(Error::new(ident.span(), "unsupported function name"))?,
386        };
387
388        let exprs = {
389            let inner;
390            syn::parenthesized!(inner in input);
391            scalar::parse_exprs(&inner)?
392        };
393
394        let parse_input = ParseChildren::new(input, flat_map.span().start());
395        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
396
397        Ok(MirRelationExpr::FlatMap { input, func, exprs })
398    }
399
400    fn parse_filter(ctx: CtxRef, input: ParseStream) -> Result {
401        use mz_expr::MirScalarExpr::CallVariadic;
402        use mz_expr::VariadicFunc::And;
403
404        let filter = input.parse::<kw::Filter>()?;
405
406        let predicates = match scalar::parse_expr(input)? {
407            CallVariadic {
408                func: And(_),
409                exprs,
410            } => exprs,
411            expr => vec![expr],
412        };
413
414        let parse_input = ParseChildren::new(input, filter.span().start());
415        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
416
417        Ok(MirRelationExpr::Filter { input, predicates })
418    }
419
420    fn parse_cross_join(ctx: CtxRef, input: ParseStream) -> Result {
421        let join = input.parse::<kw::CrossJoin>()?;
422
423        let parse_inputs = ParseChildren::new(input, join.span().start());
424        let inputs = parse_inputs.parse_many(ctx, parse_expr)?;
425
426        Ok(MirRelationExpr::Join {
427            inputs,
428            equivalences: vec![],
429            implementation: JoinImplementation::Unimplemented,
430        })
431    }
432
433    fn parse_join(ctx: CtxRef, input: ParseStream) -> Result {
434        let join = input.parse::<kw::Join>()?;
435
436        input.parse::<kw::on>()?;
437        input.parse::<syn::Token![=]>()?;
438        let inner;
439        syn::parenthesized!(inner in input);
440        let equivalences = scalar::parse_join_equivalences(&inner)?;
441
442        let parse_inputs = ParseChildren::new(input, join.span().start());
443        let inputs = parse_inputs.parse_many(ctx, parse_expr)?;
444
445        Ok(MirRelationExpr::Join {
446            inputs,
447            equivalences,
448            implementation: JoinImplementation::Unimplemented,
449        })
450    }
451
452    fn parse_distinct(ctx: CtxRef, input: ParseStream) -> Result {
453        let reduce = input.parse::<kw::Distinct>()?;
454
455        let group_key = if input.eat(kw::project) {
456            input.parse::<syn::Token![=]>()?;
457            let inner;
458            syn::bracketed!(inner in input);
459            inner.parse_comma_sep(scalar::parse_expr)?
460        } else {
461            vec![]
462        };
463
464        let monotonic = input.eat(kw::monotonic);
465
466        let expected_group_size = if input.eat(kw::exp_group_size) {
467            input.parse::<syn::Token![=]>()?;
468            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
469        } else {
470            None
471        };
472
473        let parse_inputs = ParseChildren::new(input, reduce.span().start());
474        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
475
476        Ok(MirRelationExpr::Reduce {
477            input,
478            group_key,
479            aggregates: vec![],
480            monotonic,
481            expected_group_size,
482        })
483    }
484
485    fn parse_reduce(ctx: CtxRef, input: ParseStream) -> Result {
486        let reduce = input.parse::<kw::Reduce>()?;
487
488        let group_key = if input.eat(kw::group_by) {
489            input.parse::<syn::Token![=]>()?;
490            let inner;
491            syn::bracketed!(inner in input);
492            inner.parse_comma_sep(scalar::parse_expr)?
493        } else {
494            vec![]
495        };
496
497        let aggregates = {
498            input.parse::<kw::aggregates>()?;
499            input.parse::<syn::Token![=]>()?;
500            let inner;
501            syn::bracketed!(inner in input);
502            inner.parse_comma_sep(aggregate::parse_expr)?
503        };
504
505        let monotonic = input.eat(kw::monotonic);
506
507        let expected_group_size = if input.eat(kw::exp_group_size) {
508            input.parse::<syn::Token![=]>()?;
509            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
510        } else {
511            None
512        };
513
514        let parse_inputs = ParseChildren::new(input, reduce.span().start());
515        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
516
517        Ok(MirRelationExpr::Reduce {
518            input,
519            group_key,
520            aggregates,
521            monotonic,
522            expected_group_size,
523        })
524    }
525
526    fn parse_top_k(ctx: CtxRef, input: ParseStream) -> Result {
527        let top_k = input.parse::<kw::TopK>()?;
528
529        let group_key = if input.eat(kw::group_by) {
530            input.parse::<syn::Token![=]>()?;
531            let inner;
532            syn::bracketed!(inner in input);
533            inner.parse_comma_sep(scalar::parse_column_index)?
534        } else {
535            vec![]
536        };
537
538        let order_key = if input.eat(kw::order_by) {
539            input.parse::<syn::Token![=]>()?;
540            let inner;
541            syn::bracketed!(inner in input);
542            inner.parse_comma_sep(scalar::parse_column_order)?
543        } else {
544            vec![]
545        };
546
547        let limit = if input.eat(kw::limit) {
548            input.parse::<syn::Token![=]>()?;
549            Some(scalar::parse_expr(input)?)
550        } else {
551            None
552        };
553
554        let offset = if input.eat(kw::offset) {
555            input.parse::<syn::Token![=]>()?;
556            input.parse::<syn::LitInt>()?.base10_parse::<usize>()?
557        } else {
558            0
559        };
560
561        let monotonic = input.eat(kw::monotonic);
562
563        let expected_group_size = if input.eat(kw::exp_group_size) {
564            input.parse::<syn::Token![=]>()?;
565            Some(input.parse::<syn::LitInt>()?.base10_parse::<u64>()?)
566        } else {
567            None
568        };
569
570        let parse_inputs = ParseChildren::new(input, top_k.span().start());
571        let input = Box::new(parse_inputs.parse_one(ctx, parse_expr)?);
572
573        Ok(MirRelationExpr::TopK {
574            input,
575            group_key,
576            order_key,
577            limit,
578            offset,
579            monotonic,
580            expected_group_size,
581        })
582    }
583
584    fn parse_negate(ctx: CtxRef, input: ParseStream) -> Result {
585        let negate = input.parse::<kw::Negate>()?;
586
587        let parse_input = ParseChildren::new(input, negate.span().start());
588        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
589
590        Ok(MirRelationExpr::Negate { input })
591    }
592
593    fn parse_threshold(ctx: CtxRef, input: ParseStream) -> Result {
594        let threshold = input.parse::<kw::Threshold>()?;
595
596        let parse_input = ParseChildren::new(input, threshold.span().start());
597        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
598
599        Ok(MirRelationExpr::Threshold { input })
600    }
601
602    fn parse_union(ctx: CtxRef, input: ParseStream) -> Result {
603        let union = input.parse::<kw::Union>()?;
604
605        let parse_inputs = ParseChildren::new(input, union.span().start());
606        let mut children = parse_inputs.parse_many(ctx, parse_expr)?;
607        let inputs = children.split_off(1);
608        let base = Box::new(children.into_element());
609
610        Ok(MirRelationExpr::Union { base, inputs })
611    }
612
613    fn parse_arrange_by(ctx: CtxRef, input: ParseStream) -> Result {
614        let arrange_by = input.parse::<kw::ArrangeBy>()?;
615
616        let keys = {
617            input.parse::<kw::keys>()?;
618            input.parse::<syn::Token![=]>()?;
619            let inner;
620            syn::bracketed!(inner in input);
621            inner.parse_comma_sep(|input| {
622                let inner;
623                syn::bracketed!(inner in input);
624                scalar::parse_exprs(&inner)
625            })?
626        };
627
628        let parse_input = ParseChildren::new(input, arrange_by.span().start());
629        let input = Box::new(parse_input.parse_one(ctx, parse_expr)?);
630
631        Ok(MirRelationExpr::ArrangeBy { input, keys })
632    }
633
634    fn parse_local_id(ident: syn::Ident) -> syn::Result<LocalId> {
635        if ident.to_string().starts_with('l') {
636            let n = ident.to_string()[1..]
637                .parse::<u64>()
638                .map_err(|err| Error::new(ident.span(), err.to_string()))?;
639            Ok(mz_expr::LocalId::new(n))
640        } else {
641            Err(Error::new(ident.span(), "invalid LocalId"))
642        }
643    }
644
645    #[derive(Default)]
646    pub struct FixTypesCtx {
647        env: BTreeMap<LocalId, ReprRelationType>,
648        typ: Vec<ReprRelationType>,
649    }
650
651    pub fn fix_types(
652        expr: &mut MirRelationExpr,
653        ctx: &mut FixTypesCtx,
654    ) -> std::result::Result<(), String> {
655        match expr {
656            MirRelationExpr::Let { id, value, body } => {
657                fix_types(value, ctx)?;
658                let value_typ = ctx.typ.pop().expect("value type");
659                let prior_typ = ctx.env.insert(id.clone(), value_typ);
660                fix_types(body, ctx)?;
661                ctx.env.remove(id);
662                if let Some(prior_typ) = prior_typ {
663                    ctx.env.insert(id.clone(), prior_typ);
664                }
665            }
666            MirRelationExpr::LetRec {
667                ids,
668                values,
669                body,
670                limits: _,
671            } => {
672                // An ugly-ugly hack to pass the type information of the WMR CTE
673                // to the `fix_types` pass.
674                let mut prior_typs = BTreeMap::default();
675                for (id, value) in std::iter::zip(ids.iter_mut(), values.iter_mut()) {
676                    let MirRelationExpr::Union { base, mut inputs } = value.take_dangerous() else {
677                        unreachable!("ensured by construction");
678                    };
679                    let MirRelationExpr::Get { id: _, typ, .. } = *base else {
680                        unreachable!("ensured by construction");
681                    };
682                    if let Some(prior_typ) = ctx.env.insert(id.clone(), typ) {
683                        prior_typs.insert(id.clone(), prior_typ);
684                    }
685                    *value = inputs.pop().expect("ensured by construction");
686                }
687                for value in values.iter_mut() {
688                    fix_types(value, ctx)?;
689                }
690                fix_types(body, ctx)?;
691                for id in ids.iter() {
692                    ctx.env.remove(id);
693                    if let Some(prior_typ) = prior_typs.remove(id) {
694                        ctx.env.insert(id.clone(), prior_typ);
695                    }
696                }
697            }
698            MirRelationExpr::Get {
699                id: Id::Local(id),
700                typ,
701                ..
702            } => {
703                let env_typ = match ctx.env.get(&*id) {
704                    Some(env_typ) => env_typ,
705                    None => Err(format!("Cannot fix type of unbound CTE {}", id))?,
706                };
707                *typ = env_typ.clone();
708                ctx.typ.push(env_typ.clone());
709            }
710            _ => {
711                for input in expr.children_mut() {
712                    fix_types(input, ctx)?;
713                }
714                let input_types = ctx.typ.split_off(ctx.typ.len() - expr.num_inputs());
715                ctx.typ.push(expr.typ_with_input_types(&input_types));
716            }
717        };
718
719        Ok(())
720    }
721}
722
723/// Support for parsing [mz_expr::MirScalarExpr].
724mod scalar {
725    use mz_expr::{
726        BinaryFunc, ColumnOrder, MirScalarExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, func,
727    };
728    use mz_repr::{
729        AsColumnType, Datum, ReprColumnType, ReprScalarType, Row, RowArena, SqlScalarType,
730    };
731
732    use super::*;
733
734    type Result = syn::Result<MirScalarExpr>;
735
736    pub fn parse_exprs(input: ParseStream) -> syn::Result<Vec<MirScalarExpr>> {
737        input.parse_comma_sep(parse_expr)
738    }
739
740    /// Parses a single expression.
741    ///
742    /// Because in EXPLAIN contexts parentheses might be optional, we need to
743    /// correctly handle operator precedence of infix operators.
744    ///
745    /// Currently, this works in two steps:
746    ///
747    /// 1. Convert the original infix expression to a postfix expression using
748    ///    an adapted variant of this [algorithm] with precedence taken from the
749    ///    Postgres [precedence] docs. Parenthesized operands are parsed in one
750    ///    step, so steps (3-4) from the [algorithm] are not needed here.
751    /// 2. Convert the postfix vector into a single [MirScalarExpr].
752    ///
753    /// [algorithm]: <https://www.prepbytes.com/blog/stacks/infix-to-postfix-conversion-using-stack/>
754    /// [precedence]: <https://www.postgresql.org/docs/7.2/sql-precedence.html>
755    pub fn parse_expr(input: ParseStream) -> Result {
756        let line = input.span().start().line;
757
758        /// Helper struct to keep track of the parsing state.
759        #[derive(Debug)]
760        enum Op {
761            Unr(mz_expr::UnaryFunc), // unary
762            Neg(mz_expr::UnaryFunc), // negated unary (append -.not() on fold)
763            Bin(mz_expr::BinaryFunc),
764            Var(mz_expr::VariadicFunc),
765        }
766
767        impl Op {
768            fn precedence(&self) -> Option<usize> {
769                match self {
770                    // 01: logical disjunction
771                    Op::Var(mz_expr::VariadicFunc::Or(_)) => Some(1),
772                    // 02: logical conjunction
773                    Op::Var(mz_expr::VariadicFunc::And(_)) => Some(2),
774                    // 04: equality, assignment
775                    Op::Bin(mz_expr::BinaryFunc::Eq(_)) => Some(4),
776                    Op::Bin(mz_expr::BinaryFunc::NotEq(_)) => Some(4),
777                    // 05: less than, greater than
778                    Op::Bin(mz_expr::BinaryFunc::Gt(_)) => Some(5),
779                    Op::Bin(mz_expr::BinaryFunc::Gte(_)) => Some(5),
780                    Op::Bin(mz_expr::BinaryFunc::Lt(_)) => Some(5),
781                    Op::Bin(mz_expr::BinaryFunc::Lte(_)) => Some(5),
782                    // 13: test for TRUE, FALSE, UNKNOWN, NULL
783                    Op::Unr(mz_expr::UnaryFunc::IsNull(_)) => Some(13),
784                    Op::Neg(mz_expr::UnaryFunc::IsNull(_)) => Some(13),
785                    Op::Unr(mz_expr::UnaryFunc::IsTrue(_)) => Some(13),
786                    Op::Neg(mz_expr::UnaryFunc::IsTrue(_)) => Some(13),
787                    Op::Unr(mz_expr::UnaryFunc::IsFalse(_)) => Some(13),
788                    Op::Neg(mz_expr::UnaryFunc::IsFalse(_)) => Some(13),
789                    // 14: addition, subtraction
790                    Op::Bin(mz_expr::BinaryFunc::AddInt64(_)) => Some(14),
791                    // 15: multiplication, division, modulo
792                    Op::Bin(mz_expr::BinaryFunc::MulInt64(_)) => Some(15),
793                    Op::Bin(mz_expr::BinaryFunc::DivInt64(_)) => Some(15),
794                    Op::Bin(mz_expr::BinaryFunc::ModInt64(_)) => Some(15),
795                    // unsupported
796                    _ => None,
797                }
798            }
799        }
800
801        /// Helper struct for entries in the postfix vector.
802        #[derive(Debug)]
803        enum Entry {
804            Operand(MirScalarExpr),
805            Operator(Op),
806        }
807
808        let mut opstack = vec![];
809        let mut postfix = vec![];
810        let mut exp_opd = true; // expects an argument of an operator
811
812        // Scan the given infix expression from left to right.
813        while !input.is_empty() && input.span().start().line == line {
814            // Operands and operators alternate.
815            if exp_opd {
816                postfix.push(Entry::Operand(parse_operand(input)?));
817                exp_opd = false;
818            } else {
819                // If the current symbol is an operator, then bind it to op.
820                // Else it is an operand - append it to postfix and continue.
821                let op = if input.eat(syn::Token![=]) {
822                    exp_opd = true;
823                    Op::Bin(func::Eq.into())
824                } else if input.eat(syn::Token![!=]) {
825                    exp_opd = true;
826                    Op::Bin(func::NotEq.into())
827                } else if input.eat(syn::Token![>=]) {
828                    exp_opd = true;
829                    Op::Bin(func::Gte.into())
830                } else if input.eat(syn::Token![>]) {
831                    exp_opd = true;
832                    Op::Bin(func::Gt.into())
833                } else if input.eat(syn::Token![<=]) {
834                    exp_opd = true;
835                    Op::Bin(func::Lte.into())
836                } else if input.eat(syn::Token![<]) {
837                    exp_opd = true;
838                    Op::Bin(func::Lt.into())
839                } else if input.eat(syn::Token![+]) {
840                    exp_opd = true;
841                    Op::Bin(func::AddInt64.into()) // TODO: fix placeholder
842                } else if input.eat(syn::Token![*]) {
843                    exp_opd = true;
844                    Op::Bin(func::MulInt64.into()) // TODO: fix placeholder
845                } else if input.eat(syn::Token![/]) {
846                    exp_opd = true;
847                    Op::Bin(func::DivInt64.into()) // TODO: fix placeholder
848                } else if input.eat(syn::Token![%]) {
849                    exp_opd = true;
850                    Op::Bin(func::ModInt64.into()) // TODO: fix placeholder
851                } else if input.eat(kw::AND) {
852                    exp_opd = true;
853                    Op::Var(VariadicFunc::And(func::variadic::And))
854                } else if input.eat(kw::OR) {
855                    exp_opd = true;
856                    Op::Var(VariadicFunc::Or(func::variadic::Or))
857                } else if input.eat(kw::IS) {
858                    let negate = input.eat(kw::NOT);
859
860                    let lookahead = input.lookahead1();
861                    let func = if input.look_and_eat(kw::NULL, &lookahead) {
862                        mz_expr::func::IsNull.into()
863                    } else if input.look_and_eat(kw::TRUE, &lookahead) {
864                        mz_expr::func::IsTrue.into()
865                    } else if input.look_and_eat(kw::FALSE, &lookahead) {
866                        mz_expr::func::IsFalse.into()
867                    } else {
868                        Err(lookahead.error())?
869                    };
870
871                    if negate { Op::Neg(func) } else { Op::Unr(func) }
872                } else {
873                    // We were expecting an optional operator but didn't find
874                    // anything. Exit the parsing loop and process the postfix
875                    // vector.
876                    break;
877                };
878
879                // First, pop the operators which are already on the opstack that
880                // have higher or equal precedence than the current operator and
881                // append them to the postfix.
882                while opstack
883                    .last()
884                    .map(|op1: &Op| op1.precedence() >= op.precedence())
885                    .unwrap_or(false)
886                {
887                    let op1 = opstack.pop().expect("non-empty opstack");
888                    postfix.push(Entry::Operator(op1));
889                }
890
891                // Then push the op from this iteration onto the stack.
892                opstack.push(op);
893            }
894        }
895
896        // Pop all remaining symbols from opstack and append them to postfix.
897        postfix.extend(opstack.into_iter().rev().map(Entry::Operator));
898
899        if postfix.is_empty() {
900            let msg = "Cannot parse an empty expression";
901            Err(Error::new(input.span(), msg))?
902        }
903
904        // Flatten the postfix vector into a single MirScalarExpr.
905        let mut stack = vec![];
906        postfix.reverse();
907        while let Some(entry) = postfix.pop() {
908            match entry {
909                Entry::Operand(expr) => {
910                    stack.push(expr);
911                }
912                Entry::Operator(Op::Unr(func)) => {
913                    let expr = Box::new(stack.pop().expect("non-empty stack"));
914                    stack.push(MirScalarExpr::CallUnary { func, expr });
915                }
916                Entry::Operator(Op::Neg(func)) => {
917                    let expr = Box::new(stack.pop().expect("non-empty stack"));
918                    stack.push(MirScalarExpr::CallUnary { func, expr }.not());
919                }
920                Entry::Operator(Op::Bin(func)) => {
921                    let expr2 = Box::new(stack.pop().expect("non-empty stack"));
922                    let expr1 = Box::new(stack.pop().expect("non-empty stack"));
923                    stack.push(MirScalarExpr::CallBinary { func, expr1, expr2 });
924                }
925                Entry::Operator(Op::Var(func)) => {
926                    let expr2 = stack.pop().expect("non-empty stack");
927                    let expr1 = stack.pop().expect("non-empty stack");
928                    let mut exprs = vec![];
929                    for expr in [expr1, expr2] {
930                        match expr {
931                            MirScalarExpr::CallVariadic { func: f, exprs: es } if f == func => {
932                                exprs.extend(es.into_iter());
933                            }
934                            expr => {
935                                exprs.push(expr);
936                            }
937                        }
938                    }
939                    stack.push(MirScalarExpr::CallVariadic { func, exprs });
940                }
941            }
942        }
943
944        if stack.len() != 1 {
945            let msg = "Cannot fold postfix vector into a single MirScalarExpr";
946            Err(Error::new(input.span(), msg))?
947        }
948
949        Ok(stack.pop().unwrap())
950    }
951
952    pub fn parse_operand(input: ParseStream) -> Result {
953        let lookahead = input.lookahead1();
954        if lookahead.peek(syn::Token![#]) {
955            parse_column(input)
956        } else if lookahead.peek(syn::Lit) || lookahead.peek(kw::null) {
957            parse_literal_ok(input)
958        } else if lookahead.peek(kw::error) {
959            parse_literal_err(input)
960        } else if lookahead.peek(kw::array) {
961            parse_array(input)
962        } else if lookahead.peek(kw::list) {
963            parse_list(input)
964        } else if lookahead.peek(kw::case) {
965            parse_case(input)
966        } else if lookahead.peek(syn::Ident) {
967            parse_apply(input)
968        } else if lookahead.peek(syn::token::Brace) {
969            let inner;
970            syn::braced!(inner in input);
971            parse_literal_array(&inner)
972        } else if lookahead.peek(syn::token::Bracket) {
973            let inner;
974            syn::bracketed!(inner in input);
975            parse_literal_list(&inner)
976        } else if lookahead.peek(syn::token::Paren) {
977            let inner;
978            syn::parenthesized!(inner in input);
979            parse_expr(&inner)
980        } else {
981            Err(lookahead.error())
982        }
983    }
984
985    /// Parses `case when {cond} then {then} else {els} end`.
986    fn parse_case(input: ParseStream) -> Result {
987        input.parse::<kw::case>()?;
988        if input.peek(kw::when) {
989            input.parse::<kw::when>()?;
990            let cond = parse_expr(input)?;
991            input.parse::<kw::then>()?;
992            let then = parse_expr(input)?;
993            input.parse::<syn::Token![else]>()?;
994            let els = parse_expr(input)?;
995            input.parse::<kw::end>()?;
996            Ok(MirScalarExpr::If {
997                cond: Box::new(cond),
998                then: Box::new(then),
999                els: Box::new(els),
1000            })
1001        } else {
1002            Err(Error::new(input.span(), "expected 'when' after 'case'"))
1003        }
1004    }
1005
1006    pub fn parse_column(input: ParseStream) -> Result {
1007        Ok(MirScalarExpr::column(parse_column_index(input)?))
1008    }
1009
1010    pub fn parse_column_index(input: ParseStream) -> syn::Result<usize> {
1011        input.parse::<syn::Token![#]>()?;
1012        input.parse::<syn::LitInt>()?.base10_parse::<usize>()
1013    }
1014
1015    pub fn parse_column_order(input: ParseStream) -> syn::Result<ColumnOrder> {
1016        input.parse::<syn::Token![#]>()?;
1017        let column = input.parse::<syn::LitInt>()?.base10_parse::<usize>()?;
1018        let desc = input.eat(kw::desc) || !input.eat(kw::asc);
1019        let nulls_last = input.eat(kw::nulls_last) || !input.eat(kw::nulls_first);
1020        Ok(ColumnOrder {
1021            column,
1022            desc,
1023            nulls_last,
1024        })
1025    }
1026
1027    fn parse_literal_ok(input: ParseStream) -> Result {
1028        let mut row = Row::default();
1029        let mut packer = row.packer();
1030
1031        let typ = if input.eat(kw::null) {
1032            packer.push(Datum::Null);
1033            input.parse::<syn::Token![::]>()?;
1034            ReprColumnType {
1035                scalar_type: analyses::parse_scalar_type(input)?,
1036                nullable: true,
1037            }
1038        } else {
1039            match input.parse::<syn::Lit>()? {
1040                syn::Lit::Str(l) => {
1041                    packer.push(Datum::from(l.value().as_str()));
1042                    Ok(ReprColumnType::from(&String::as_column_type()))
1043                }
1044                syn::Lit::Int(l) => {
1045                    packer.push(Datum::from(l.base10_parse::<i64>()?));
1046                    Ok(ReprColumnType::from(&i64::as_column_type()))
1047                }
1048                syn::Lit::Float(l) => {
1049                    packer.push(Datum::from(l.base10_parse::<f64>()?));
1050                    Ok(ReprColumnType::from(&f64::as_column_type()))
1051                }
1052                syn::Lit::Bool(l) => {
1053                    packer.push(Datum::from(l.value));
1054                    Ok(ReprColumnType::from(&bool::as_column_type()))
1055                }
1056                _ => Err(Error::new(input.span(), "cannot parse literal")),
1057            }?
1058        };
1059
1060        Ok(MirScalarExpr::Literal(Ok(row), typ))
1061    }
1062    fn parse_literal_err(input: ParseStream) -> Result {
1063        input.parse::<kw::error>()?;
1064        let mut msg = {
1065            let content;
1066            syn::parenthesized!(content in input);
1067            content.parse::<syn::LitStr>()?.value()
1068        };
1069        let err = if msg.starts_with("internal error: ") {
1070            Ok(mz_expr::EvalError::Internal(msg.split_off(16).into()))
1071        } else {
1072            Err(Error::new(msg.span(), "expected `internal error: $msg`"))
1073        }?;
1074        Ok(MirScalarExpr::literal(Err(err), ReprScalarType::Bool))
1075    }
1076
1077    fn parse_literal_array(input: ParseStream) -> Result {
1078        use mz_expr::func::variadic::ArrayCreate;
1079
1080        let elem_type = SqlScalarType::Int64; // FIXME
1081        let func = VariadicFunc::ArrayCreate(ArrayCreate { elem_type });
1082        let exprs = input.parse_comma_sep(parse_literal_ok)?;
1083
1084        // Evaluate into a datum
1085        let temp_storage = RowArena::default();
1086        let datum = func.eval(&[], &temp_storage, &exprs).expect("datum");
1087        let typ = ReprScalarType::from(&SqlScalarType::Array(Box::new(SqlScalarType::Int64))); // FIXME
1088        Ok(MirScalarExpr::literal_ok(datum, typ))
1089    }
1090    fn parse_literal_list(input: ParseStream) -> Result {
1091        use mz_expr::func::variadic::ListCreate;
1092
1093        let elem_type = SqlScalarType::Int64; // FIXME
1094        let func = VariadicFunc::ListCreate(ListCreate { elem_type });
1095        let exprs = input.parse_comma_sep(parse_literal_ok)?;
1096
1097        // Evaluate into a datum
1098        let temp_storage = RowArena::default();
1099        let datum = func.eval(&[], &temp_storage, &exprs).expect("datum");
1100        let typ = ReprScalarType::from(&SqlScalarType::Array(Box::new(SqlScalarType::Int64))); // FIXME
1101        Ok(MirScalarExpr::literal_ok(datum, typ))
1102    }
1103    fn parse_array(input: ParseStream) -> Result {
1104        use mz_expr::func::variadic::ArrayCreate;
1105
1106        input.parse::<kw::array>()?;
1107
1108        // parse brackets
1109        let inner;
1110        syn::bracketed!(inner in input);
1111
1112        let elem_type = SqlScalarType::Int64; // FIXME
1113        let func = ArrayCreate { elem_type };
1114        let exprs = inner.parse_comma_sep(parse_expr)?;
1115
1116        Ok(MirScalarExpr::call_variadic(func, exprs))
1117    }
1118
1119    fn parse_list(input: ParseStream) -> Result {
1120        use mz_expr::func::variadic::ListCreate;
1121
1122        input.parse::<kw::list>()?;
1123
1124        // parse brackets
1125        let inner;
1126        syn::bracketed!(inner in input);
1127
1128        let elem_type = SqlScalarType::Int64; // FIXME
1129        let func = ListCreate { elem_type };
1130        let exprs = inner.parse_comma_sep(parse_expr)?;
1131
1132        Ok(MirScalarExpr::call_variadic(func, exprs))
1133    }
1134
1135    fn parse_apply(input: ParseStream) -> Result {
1136        let ident = input.parse::<syn::Ident>()?;
1137
1138        // parse parentheses
1139        let inner;
1140        syn::parenthesized!(inner in input);
1141
1142        let parse_nullary = |func: UnmaterializableFunc| -> Result {
1143            Ok(MirScalarExpr::CallUnmaterializable(func))
1144        };
1145        let parse_unary = |func: UnaryFunc| -> Result {
1146            let expr = Box::new(parse_expr(&inner)?);
1147            Ok(MirScalarExpr::CallUnary { func, expr })
1148        };
1149        let parse_binary = |func: BinaryFunc| -> Result {
1150            let expr1 = Box::new(parse_expr(&inner)?);
1151            inner.parse::<syn::Token![,]>()?;
1152            let expr2 = Box::new(parse_expr(&inner)?);
1153            Ok(MirScalarExpr::CallBinary { func, expr1, expr2 })
1154        };
1155        let parse_variadic = |func: VariadicFunc| -> Result {
1156            let exprs = inner.parse_comma_sep(parse_expr)?;
1157            Ok(MirScalarExpr::call_variadic(func, exprs))
1158        };
1159
1160        // Infix binary and variadic function calls are handled in `parse_scalar_expr`.
1161        //
1162        // Some restrictions apply with the current state of the code,
1163        // most notably one cannot handle overloaded function names because we don't want to do
1164        // name resolution in the parser.
1165        match ident.to_string().to_lowercase().as_str() {
1166            // Supported unmaterializable (a.k.a. nullary) functions:
1167            "mz_environment_id" => parse_nullary(UnmaterializableFunc::MzEnvironmentId),
1168            // Supported unary functions:
1169            "abs" => parse_unary(func::AbsInt64.into()),
1170            "not" => parse_unary(func::Not.into()),
1171            // Supported binary functions:
1172            "ltrim" => parse_binary(func::TrimLeading.into()),
1173            // Supported variadic functions:
1174            "greatest" => parse_variadic(VariadicFunc::Greatest(func::variadic::Greatest)),
1175            _ => Err(Error::new(ident.span(), "unsupported function name")),
1176        }
1177    }
1178
1179    pub fn parse_join_equivalences(input: ParseStream) -> syn::Result<Vec<Vec<MirScalarExpr>>> {
1180        let mut equivalences = vec![];
1181        while !input.is_empty() {
1182            let mut equivalence = vec![];
1183            loop {
1184                let mut worklist = vec![parse_operand(input)?];
1185                while let Some(operand) = worklist.pop() {
1186                    // Be more lenient and support parenthesized equivalences,
1187                    // e.g. `... AND (x = u + v = z + 1) AND ...`.
1188                    if let MirScalarExpr::CallBinary {
1189                        func: BinaryFunc::Eq(_),
1190                        expr1,
1191                        expr2,
1192                    } = operand
1193                    {
1194                        // We reverse the order in the worklist in order to get
1195                        // the correct order in the equivalence class.
1196                        worklist.push(*expr2);
1197                        worklist.push(*expr1);
1198                    } else {
1199                        equivalence.push(operand);
1200                    }
1201                }
1202                if !input.eat(syn::Token![=]) {
1203                    break;
1204                }
1205            }
1206            equivalences.push(equivalence);
1207            input.eat(kw::AND);
1208        }
1209        Ok(equivalences)
1210    }
1211}
1212
1213/// Support for parsing [mz_expr::AggregateExpr].
1214mod aggregate {
1215    use mz_expr::{AggregateExpr, MirScalarExpr};
1216
1217    use super::*;
1218
1219    type Result = syn::Result<AggregateExpr>;
1220
1221    pub fn parse_expr(input: ParseStream) -> Result {
1222        use mz_expr::AggregateFunc::*;
1223
1224        // Some restrictions apply with the current state of the code,
1225        // most notably one cannot handle overloaded function names because we don't want to do
1226        // name resolution in the parser.
1227        let ident = input.parse::<syn::Ident>()?;
1228        let func = match ident.to_string().to_lowercase().as_str() {
1229            "count" => Count,
1230            "any" => Any,
1231            "all" => All,
1232            "max" => MaxInt64,
1233            "min" => MinInt64,
1234            "sum" => SumInt64,
1235            _ => Err(Error::new(ident.span(), "unsupported function name"))?,
1236        };
1237
1238        // parse parentheses
1239        let inner;
1240        syn::parenthesized!(inner in input);
1241
1242        if func == Count && inner.eat(syn::Token![*]) {
1243            Ok(AggregateExpr {
1244                func,
1245                expr: MirScalarExpr::literal_true(),
1246                distinct: false, // TODO: fix explain output
1247            })
1248        } else {
1249            let distinct = inner.eat(kw::distinct);
1250            let expr = scalar::parse_expr(&inner)?;
1251            Ok(AggregateExpr {
1252                func,
1253                expr,
1254                distinct,
1255            })
1256        }
1257    }
1258}
1259
1260/// Support for parsing [mz_repr::Row].
1261mod row {
1262    use mz_repr::{Datum, Row, RowPacker};
1263
1264    use super::*;
1265
1266    impl Parse for Parsed<Row> {
1267        fn parse(input: ParseStream) -> syn::Result<Self> {
1268            let mut row = Row::default();
1269            let mut packer = ParseRow::new(&mut row);
1270
1271            loop {
1272                if input.is_empty() {
1273                    break;
1274                }
1275                packer.parse_datum(input)?;
1276                if input.is_empty() {
1277                    break;
1278                }
1279                input.parse::<syn::Token![,]>()?;
1280            }
1281
1282            Ok(Parsed(row))
1283        }
1284    }
1285
1286    impl From<Parsed<Row>> for Row {
1287        fn from(parsed: Parsed<Row>) -> Self {
1288            parsed.0
1289        }
1290    }
1291
1292    struct ParseRow<'a>(RowPacker<'a>);
1293
1294    impl<'a> ParseRow<'a> {
1295        fn new(row: &'a mut Row) -> Self {
1296            Self(row.packer())
1297        }
1298
1299        fn parse_datum(&mut self, input: ParseStream) -> syn::Result<()> {
1300            if input.eat(kw::null) {
1301                self.0.push(Datum::Null)
1302            } else {
1303                match input.parse::<syn::Lit>()? {
1304                    syn::Lit::Str(l) => self.0.push(Datum::from(l.value().as_str())),
1305                    syn::Lit::Int(l) => self.0.push(Datum::from(l.base10_parse::<i64>()?)),
1306                    syn::Lit::Float(l) => self.0.push(Datum::from(l.base10_parse::<f64>()?)),
1307                    syn::Lit::Bool(l) => self.0.push(Datum::from(l.value)),
1308                    _ => Err(Error::new(input.span(), "cannot parse literal"))?,
1309                }
1310            }
1311            Ok(())
1312        }
1313    }
1314}
1315
1316mod analyses {
1317    use mz_repr::{ReprColumnType, ReprScalarType};
1318
1319    use super::*;
1320
1321    #[derive(Default)]
1322    pub struct Analyses {
1323        pub types: Option<Vec<ReprColumnType>>,
1324        pub keys: Option<Vec<Vec<usize>>>,
1325    }
1326
1327    pub fn parse_analyses(input: ParseStream) -> syn::Result<Analyses> {
1328        let mut analyses = Analyses::default();
1329
1330        // Analyses are optional, appearing after a `//` at the end of the
1331        // line. However, since the syn lexer eats comments, we assume that `//`
1332        // was replaced with `::` upfront.
1333        if input.eat(syn::Token![::]) {
1334            let inner;
1335            syn::braced!(inner in input);
1336
1337            let (start, end) = (inner.span().start(), inner.span().end());
1338            if start.line != end.line {
1339                let msg = "analyses should not span more than one line".to_string();
1340                Err(Error::new(inner.span(), msg))?
1341            }
1342
1343            while inner.peek(syn::Ident) {
1344                let ident = inner.parse::<syn::Ident>()?.to_string();
1345                match ident.as_str() {
1346                    "types" => {
1347                        inner.parse::<syn::Token![:]>()?;
1348                        let value = inner.parse::<syn::LitStr>()?.value();
1349                        analyses.types = Some(parse_types.parse_str(&value)?);
1350                    }
1351                    // TODO: support keys
1352                    key => {
1353                        let msg = format!("unexpected analysis type `{}`", key);
1354                        Err(Error::new(inner.span(), msg))?;
1355                    }
1356                }
1357            }
1358        }
1359        Ok(analyses)
1360    }
1361
1362    fn parse_types(input: ParseStream) -> syn::Result<Vec<ReprColumnType>> {
1363        let inner;
1364        syn::parenthesized!(inner in input);
1365        inner.parse_comma_sep(parse_column_type)
1366    }
1367
1368    pub fn parse_column_type(input: ParseStream) -> syn::Result<ReprColumnType> {
1369        let scalar_type = parse_scalar_type(input)?;
1370        Ok(ReprColumnType {
1371            scalar_type,
1372            nullable: input.eat(syn::Token![?]),
1373        })
1374    }
1375
1376    pub fn parse_scalar_type(input: ParseStream) -> syn::Result<ReprScalarType> {
1377        let lookahead = input.lookahead1();
1378
1379        let scalar_type = if input.look_and_eat(bigint, &lookahead) {
1380            ReprScalarType::Int64
1381        } else if input.look_and_eat(double, &lookahead) {
1382            input.parse::<precision>()?;
1383            ReprScalarType::Float64
1384        } else if input.look_and_eat(boolean, &lookahead) {
1385            ReprScalarType::Bool
1386        } else if input.look_and_eat(character, &lookahead) {
1387            input.parse::<varying>()?;
1388            ReprScalarType::String
1389        } else if input.look_and_eat(integer, &lookahead) {
1390            ReprScalarType::Int32
1391        } else if input.look_and_eat(smallint, &lookahead) {
1392            ReprScalarType::Int16
1393        } else if input.look_and_eat(text, &lookahead) {
1394            ReprScalarType::String
1395        } else {
1396            Err(lookahead.error())?
1397        };
1398
1399        Ok(scalar_type)
1400    }
1401
1402    syn::custom_keyword!(bigint);
1403    syn::custom_keyword!(boolean);
1404    syn::custom_keyword!(character);
1405    syn::custom_keyword!(double);
1406    syn::custom_keyword!(integer);
1407    syn::custom_keyword!(precision);
1408    syn::custom_keyword!(smallint);
1409    syn::custom_keyword!(text);
1410    syn::custom_keyword!(varying);
1411}
1412
1413pub enum Def {
1414    Source {
1415        name: String,
1416        cols: Vec<String>,
1417        typ: mz_repr::SqlRelationType,
1418    },
1419}
1420
1421mod def {
1422    use mz_repr::{SqlColumnType, SqlRelationType};
1423
1424    use super::*;
1425
1426    pub fn parse_def(ctx: CtxRef, input: ParseStream) -> syn::Result<Def> {
1427        parse_def_source(ctx, input) // only one variant for now
1428    }
1429
1430    fn parse_def_source(ctx: CtxRef, input: ParseStream) -> syn::Result<Def> {
1431        let reduce = input.parse::<def::DefSource>()?;
1432
1433        let name = {
1434            input.parse::<def::name>()?;
1435            input.parse::<syn::Token![=]>()?;
1436            input.parse::<syn::Ident>()?.to_string()
1437        };
1438
1439        let keys = if input.eat(kw::keys) {
1440            input.parse::<syn::Token![=]>()?;
1441            let inner;
1442            syn::bracketed!(inner in input);
1443            inner.parse_comma_sep(|input| {
1444                let inner;
1445                syn::bracketed!(inner in input);
1446                inner.parse_comma_sep(scalar::parse_column_index)
1447            })?
1448        } else {
1449            vec![]
1450        };
1451
1452        let parse_inputs = ParseChildren::new(input, reduce.span().start());
1453        let (cols, column_types) = {
1454            let source_columns = parse_inputs.parse_many(ctx, parse_def_source_column)?;
1455            let mut column_names = vec![];
1456            let mut column_types = vec![];
1457            for (column_name, column_type) in source_columns {
1458                column_names.push(column_name);
1459                column_types.push(column_type);
1460            }
1461            (column_names, column_types)
1462        };
1463
1464        let typ = SqlRelationType { column_types, keys };
1465
1466        Ok(Def::Source { name, cols, typ })
1467    }
1468
1469    fn parse_def_source_column(
1470        _ctx: CtxRef,
1471        input: ParseStream,
1472    ) -> syn::Result<(String, SqlColumnType)> {
1473        input.parse::<syn::Token![-]>()?;
1474        let column_name = input.parse::<syn::Ident>()?.to_string();
1475        input.parse::<syn::Token![:]>()?;
1476        let column_type = SqlColumnType::from_repr(&analyses::parse_column_type(input)?);
1477        Ok((column_name, column_type))
1478    }
1479
1480    syn::custom_keyword!(DefSource);
1481    syn::custom_keyword!(name);
1482}
1483
1484/// Help utilities used by sibling modules.
1485mod util {
1486    use syn::parse::{Lookahead1, ParseBuffer, Peek};
1487
1488    use super::*;
1489
1490    /// Extension methods for [`syn::parse::ParseBuffer`].
1491    pub trait ParseBufferExt<'a> {
1492        fn look_and_eat<T: Eat>(&self, token: T, lookahead: &Lookahead1<'a>) -> bool;
1493
1494        /// Consumes a token `T` if present.
1495        fn eat<T: Eat>(&self, t: T) -> bool;
1496
1497        /// Consumes two tokens `T1 T2` if present in that order.
1498        fn eat2<T1: Eat, T2: Eat>(&self, t1: T1, t2: T2) -> bool;
1499
1500        /// Consumes three tokens `T1 T2 T3` if present in that order.
1501        fn eat3<T1: Eat, T2: Eat, T3: Eat>(&self, t1: T1, t2: T2, t3: T3) -> bool;
1502
1503        // Parse a comma-separated list of items into a vector.
1504        fn parse_comma_sep<T>(&self, p: fn(ParseStream) -> syn::Result<T>) -> syn::Result<Vec<T>>;
1505    }
1506
1507    impl<'a> ParseBufferExt<'a> for ParseBuffer<'a> {
1508        /// Consumes a token `T` if present, looking it up using the provided
1509        /// [`Lookahead1`] instance.
1510        fn look_and_eat<T: Eat>(&self, token: T, lookahead: &Lookahead1<'a>) -> bool {
1511            if lookahead.peek(token) {
1512                self.parse::<T::Token>().unwrap();
1513                true
1514            } else {
1515                false
1516            }
1517        }
1518
1519        fn eat<T: Eat>(&self, t: T) -> bool {
1520            if self.peek(t) {
1521                self.parse::<T::Token>().unwrap();
1522                true
1523            } else {
1524                false
1525            }
1526        }
1527
1528        fn eat2<T1: Eat, T2: Eat>(&self, t1: T1, t2: T2) -> bool {
1529            if self.peek(t1) && self.peek2(t2) {
1530                self.parse::<T1::Token>().unwrap();
1531                self.parse::<T2::Token>().unwrap();
1532                true
1533            } else {
1534                false
1535            }
1536        }
1537
1538        fn eat3<T1: Eat, T2: Eat, T3: Eat>(&self, t1: T1, t2: T2, t3: T3) -> bool {
1539            if self.peek(t1) && self.peek2(t2) && self.peek3(t3) {
1540                self.parse::<T1::Token>().unwrap();
1541                self.parse::<T2::Token>().unwrap();
1542                self.parse::<T3::Token>().unwrap();
1543                true
1544            } else {
1545                false
1546            }
1547        }
1548
1549        fn parse_comma_sep<T>(&self, p: fn(ParseStream) -> syn::Result<T>) -> syn::Result<Vec<T>> {
1550            Ok(self
1551                .parse_terminated(p, syn::Token![,])?
1552                .into_iter()
1553                .collect::<Vec<_>>())
1554        }
1555    }
1556
1557    // Helper trait for types that can be eaten.
1558    //
1559    // Implementing types must also implement [`Peek`], and the associated
1560    // [`Peek::Token`] type should implement [`Parse`]). For some reason the
1561    // latter bound is not present in [`Peek`] even if it makes a lot of sense,
1562    // which is why we need this helper.
1563    pub trait Eat: Peek<Token = Self::_Token> {
1564        type _Token: Parse;
1565    }
1566
1567    impl<T> Eat for T
1568    where
1569        T: Peek,
1570        T::Token: Parse,
1571    {
1572        type _Token = T::Token;
1573    }
1574
1575    pub struct Ctx<'a> {
1576        pub catalog: &'a TestCatalog,
1577    }
1578
1579    pub type CtxRef<'a> = &'a Ctx<'a>;
1580
1581    /// Newtype for external types that need to implement [Parse].
1582    pub struct Parsed<T>(pub T);
1583
1584    /// Provides facilities for parsing
1585    pub struct ParseChildren<'a> {
1586        stream: ParseStream<'a>,
1587        parent: LineColumn,
1588    }
1589
1590    impl<'a> ParseChildren<'a> {
1591        pub fn new(stream: ParseStream<'a>, parent: LineColumn) -> Self {
1592            Self { stream, parent }
1593        }
1594
1595        pub fn parse_one<C, T>(
1596            &self,
1597            ctx: C,
1598            function: fn(C, ParseStream) -> syn::Result<T>,
1599        ) -> syn::Result<T> {
1600            match self.maybe_child() {
1601                Ok(_) => function(ctx, self.stream),
1602                Err(e) => Err(e),
1603            }
1604        }
1605
1606        pub fn parse_many<C: Copy, T>(
1607            &self,
1608            ctx: C,
1609            function: fn(C, ParseStream) -> syn::Result<T>,
1610        ) -> syn::Result<Vec<T>> {
1611            let mut inputs = vec![self.parse_one(ctx, function)?];
1612            while self.maybe_child().is_ok() {
1613                inputs.push(function(ctx, self.stream)?);
1614            }
1615            Ok(inputs)
1616        }
1617
1618        fn maybe_child(&self) -> syn::Result<()> {
1619            let start = self.stream.span().start();
1620            if start.line <= self.parent.line {
1621                let msg = format!("child expected at line > {}", self.parent.line);
1622                Err(Error::new(self.stream.span(), msg))?
1623            }
1624            if start.column != self.parent.column + 2 {
1625                let msg = format!("child expected at column {}", self.parent.column + 2);
1626                Err(Error::new(self.stream.span(), msg))?
1627            }
1628            Ok(())
1629        }
1630    }
1631}
1632
/// Custom keywords used while parsing.
///
/// Each `syn::custom_keyword!` invocation below declares one keyword named
/// after the identifier it is given. Keywords are case-sensitive, so pairs
/// such as `distinct`/`Distinct`, `null`/`NULL` and `project`/`Project` are
/// separate entries. The list is kept in case-insensitive alphabetical order.
mod kw {
    syn::custom_keyword!(aggregates);
    syn::custom_keyword!(AND);
    syn::custom_keyword!(ArrangeBy);
    syn::custom_keyword!(array);
    syn::custom_keyword!(asc);
    // case when ... then ... else ... end
    // (`when`, `then` and `end` are declared further down; `else` is not
    // declared in this module.)
    syn::custom_keyword!(case);
    syn::custom_keyword!(Constant);
    syn::custom_keyword!(CrossJoin);
    syn::custom_keyword!(cte);
    syn::custom_keyword!(desc);
    syn::custom_keyword!(distinct);
    syn::custom_keyword!(Distinct);
    syn::custom_keyword!(empty);
    syn::custom_keyword!(end);
    syn::custom_keyword!(eq);
    syn::custom_keyword!(error);
    syn::custom_keyword!(exp_group_size);
    syn::custom_keyword!(FALSE);
    syn::custom_keyword!(Filter);
    syn::custom_keyword!(FlatMap);
    syn::custom_keyword!(Get);
    syn::custom_keyword!(group_by);
    syn::custom_keyword!(IS);
    syn::custom_keyword!(Join);
    syn::custom_keyword!(keys);
    syn::custom_keyword!(limit);
    syn::custom_keyword!(list);
    syn::custom_keyword!(Map);
    syn::custom_keyword!(monotonic);
    syn::custom_keyword!(Mutually);
    syn::custom_keyword!(Negate);
    syn::custom_keyword!(NOT);
    syn::custom_keyword!(null);
    syn::custom_keyword!(NULL);
    syn::custom_keyword!(nulls_first);
    syn::custom_keyword!(nulls_last);
    syn::custom_keyword!(offset);
    syn::custom_keyword!(on);
    syn::custom_keyword!(OR);
    syn::custom_keyword!(order_by);
    syn::custom_keyword!(project);
    syn::custom_keyword!(Project);
    syn::custom_keyword!(Recursive);
    syn::custom_keyword!(Reduce);
    syn::custom_keyword!(Return);
    syn::custom_keyword!(then);
    syn::custom_keyword!(Threshold);
    syn::custom_keyword!(TopK);
    syn::custom_keyword!(TRUE);
    syn::custom_keyword!(Union);
    syn::custom_keyword!(when);
    syn::custom_keyword!(With);
    syn::custom_keyword!(x);
}