mz_sql/plan/
transform_ast.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL ASTs.
11//!
12//! Most query optimizations are performed by the dataflow layer, but some
13//! are much easier to perform in SQL. Someday, we'll want our own SQL IR,
14//! but for now we just use the parser's AST directly.
15
16use mz_ore::stack::{CheckedRecursion, RecursionGuard};
17use mz_repr::namespaces::{MZ_CATALOG_SCHEMA, MZ_UNSAFE_SCHEMA, PG_CATALOG_SCHEMA};
18use mz_sql_parser::ast::visit_mut::{self, VisitMut, VisitMutNode};
19use mz_sql_parser::ast::{
20    Expr, Function, FunctionArgs, HomogenizingFunction, Ident, IsExprConstruct, Op, OrderByExpr,
21    Query, Select, SelectItem, TableAlias, TableFactor, TableWithJoins, Value, WindowSpec,
22};
23use mz_sql_parser::ident;
24use uuid::Uuid;
25
26use crate::names::{Aug, PartialItemName, ResolvedDataType, ResolvedItemName};
27use crate::normalize;
28use crate::plan::{PlanError, StatementContext};
29
30pub(crate) fn transform<N>(scx: &StatementContext, node: &mut N) -> Result<(), PlanError>
31where
32    N: for<'a> VisitMutNode<'a, Aug>,
33{
34    let mut func_rewriter = FuncRewriter::new(scx);
35    node.visit_mut(&mut func_rewriter);
36    func_rewriter.status?;
37
38    let mut desugarer = Desugarer::new(scx);
39    node.visit_mut(&mut desugarer);
40    desugarer.status
41}
42
43// Transforms various functions to forms that are more easily handled by the
44// planner.
45//
46// Specifically:
47//
48//   * Rewrites the `mod` function to the `%` binary operator, so the modulus
49//     code only needs to handle the operator form.
50//
51//   * Rewrites the `nullif` function to a `CASE` statement, to reuse the code
52//     for planning equality of datums.
53//
54//   * Rewrites `avg(col)` to `sum(col) / count(col)`, so that we can pretend
55//     the `avg` aggregate function doesn't exist from here on out. This also
56//     has the nice side effect of reusing the division planning logic, which
57//     is not trivial for some types, like decimals.
58//
59//   * Rewrites the suite of standard deviation and variance functions in a
60//     manner similar to `avg`.
61//
62// TODO(sploiselle): rewrite these in terms of func::sql_op!
63struct FuncRewriter<'a> {
64    scx: &'a StatementContext<'a>,
65    status: Result<(), PlanError>,
66    rewriting_table_factor: bool,
67}
68
69impl<'a> FuncRewriter<'a> {
70    fn new(scx: &'a StatementContext<'a>) -> FuncRewriter<'a> {
71        FuncRewriter {
72            scx,
73            status: Ok(()),
74            rewriting_table_factor: false,
75        }
76    }
77
78    fn resolve_known_valid_data_type(&self, name: &PartialItemName) -> ResolvedDataType {
79        let item = self
80            .scx
81            .catalog
82            .resolve_type(name)
83            .expect("data type known to be valid");
84        let full_name = self.scx.catalog.resolve_full_name(item.name());
85        ResolvedDataType::Named {
86            id: item.id(),
87            qualifiers: item.name().qualifiers.clone(),
88            full_name,
89            modifiers: vec![],
90            print_id: true,
91        }
92    }
93
94    fn int32_data_type(&self) -> ResolvedDataType {
95        self.resolve_known_valid_data_type(&PartialItemName {
96            database: None,
97            schema: Some(PG_CATALOG_SCHEMA.into()),
98            item: "int4".into(),
99        })
100    }
101
102    // Divides `lhs` by `rhs` but replaces division-by-zero errors with NULL;
103    // note that this is semantically equivalent to `NULLIF(rhs, 0)`.
104    fn plan_divide(lhs: Expr<Aug>, rhs: Expr<Aug>) -> Expr<Aug> {
105        lhs.divide(Expr::Case {
106            operand: None,
107            conditions: vec![rhs.clone().equals(Expr::number("0"))],
108            results: vec![Expr::null()],
109            else_result: Some(Box::new(rhs)),
110        })
111    }
112
113    fn plan_agg(
114        &mut self,
115        name: ResolvedItemName,
116        expr: Expr<Aug>,
117        order_by: Vec<OrderByExpr<Aug>>,
118        filter: Option<Box<Expr<Aug>>>,
119        distinct: bool,
120        over: Option<WindowSpec<Aug>>,
121    ) -> Expr<Aug> {
122        if self.rewriting_table_factor && self.status.is_ok() {
123            self.status = Err(PlanError::Unstructured(
124                "aggregate functions are not supported in functions in FROM".to_string(),
125            ))
126        }
127        Expr::Function(Function {
128            name,
129            args: FunctionArgs::Args {
130                args: vec![expr],
131                order_by,
132            },
133            filter,
134            over,
135            distinct,
136        })
137    }
138
139    fn plan_avg(
140        &mut self,
141        expr: Expr<Aug>,
142        filter: Option<Box<Expr<Aug>>>,
143        distinct: bool,
144        over: Option<WindowSpec<Aug>>,
145    ) -> Expr<Aug> {
146        let sum = self
147            .plan_agg(
148                self.scx
149                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
150                expr.clone(),
151                vec![],
152                filter.clone(),
153                distinct,
154                over.clone(),
155            )
156            .call_unary(
157                self.scx
158                    .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
159            );
160        let count = self.plan_agg(
161            self.scx
162                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
163            expr,
164            vec![],
165            filter,
166            distinct,
167            over,
168        );
169        Self::plan_divide(sum, count)
170    }
171
172    /// Same as `plan_avg` but internally uses `mz_avg_promotion_internal_v1`.
173    fn plan_avg_internal_v1(
174        &mut self,
175        expr: Expr<Aug>,
176        filter: Option<Box<Expr<Aug>>>,
177        distinct: bool,
178        over: Option<WindowSpec<Aug>>,
179    ) -> Expr<Aug> {
180        let sum = self
181            .plan_agg(
182                self.scx
183                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
184                expr.clone(),
185                vec![],
186                filter.clone(),
187                distinct,
188                over.clone(),
189            )
190            .call_unary(
191                self.scx
192                    .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion_internal_v1"]),
193            );
194        let count = self.plan_agg(
195            self.scx
196                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
197            expr,
198            vec![],
199            filter,
200            distinct,
201            over,
202        );
203        Self::plan_divide(sum, count)
204    }
205
206    fn plan_variance(
207        &mut self,
208        expr: Expr<Aug>,
209        filter: Option<Box<Expr<Aug>>>,
210        distinct: bool,
211        sample: bool,
212        over: Option<WindowSpec<Aug>>,
213    ) -> Expr<Aug> {
214        // N.B. this variance calculation uses the "textbook" algorithm, which
215        // is known to accumulate problematic amounts of error. The numerically
216        // stable variants, the most well-known of which is Welford's, are
217        // however difficult to implement inside of Differential Dataflow, as
218        // they do not obviously support retractions efficiently (database-issues#436).
219        //
220        // The code below converts var_samp(x) into
221        //
222        //     (sum(x²) - sum(x)² / count(x)) / (count(x) - 1)
223        //
224        // and var_pop(x) into:
225        //
226        //     (sum(x²) - sum(x)² / count(x)) / count(x)
227        //
228        let expr = expr.call_unary(
229            self.scx
230                .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
231        );
232        let expr_squared = expr.clone().multiply(expr.clone());
233        let sum_squares = self.plan_agg(
234            self.scx
235                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
236            expr_squared,
237            vec![],
238            filter.clone(),
239            distinct,
240            over.clone(),
241        );
242        let sum = self.plan_agg(
243            self.scx
244                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
245            expr.clone(),
246            vec![],
247            filter.clone(),
248            distinct,
249            over.clone(),
250        );
251        let sum_squared = sum.clone().multiply(sum);
252        let count = self.plan_agg(
253            self.scx
254                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
255            expr,
256            vec![],
257            filter,
258            distinct,
259            over,
260        );
261        let result = Self::plan_divide(
262            sum_squares.minus(Self::plan_divide(sum_squared, count.clone())),
263            if sample {
264                count.minus(Expr::number("1"))
265            } else {
266                count
267            },
268        );
269        // Result is _basically_ what we want, except
270        // that due to numerical inaccuracy, it might be a negative
271        // number very close to zero when it should mathematically be zero.
272        // This makes it so `stddev` fails as it tries to take the square root
273        // of a negative number.
274        // So, we need the following logic:
275        // If `result` is NULL, return NULL (no surprise here)
276        // Otherwise, if `result` is >0, return `result` (no surprise here either)
277        // Otherwise, return 0.
278        //
279        // Unfortunately, we can't use `GREATEST` directly for this,
280        // since `greatest(NULL, 0)` is 0, not NULL, so we need to
281        // create a `Case` expression that computes `result`
282        // twice. Hopefully the optimizer can deal with this!
283        let result_is_null = Expr::IsExpr {
284            expr: Box::new(result.clone()),
285            construct: IsExprConstruct::Null,
286            negated: false,
287        };
288        Expr::Case {
289            operand: None,
290            conditions: vec![result_is_null],
291            results: vec![Expr::Value(Value::Null)],
292            else_result: Some(Box::new(Expr::HomogenizingFunction {
293                function: HomogenizingFunction::Greatest,
294                exprs: vec![result, Expr::number("0")],
295            })),
296        }
297    }
298
299    fn plan_stddev(
300        &mut self,
301        expr: Expr<Aug>,
302        filter: Option<Box<Expr<Aug>>>,
303        distinct: bool,
304        sample: bool,
305        over: Option<WindowSpec<Aug>>,
306    ) -> Expr<Aug> {
307        self.plan_variance(expr, filter, distinct, sample, over)
308            .call_unary(
309                self.scx
310                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sqrt"]),
311            )
312    }
313
314    fn plan_bool_and(
315        &mut self,
316        expr: Expr<Aug>,
317        filter: Option<Box<Expr<Aug>>>,
318        distinct: bool,
319        over: Option<WindowSpec<Aug>>,
320    ) -> Expr<Aug> {
321        // The code below converts `bool_and(x)` into:
322        //
323        //     sum((NOT x)::int4) = 0
324        //
325        // It is tempting to use `count` instead, but count does not return NULL
326        // when all input values are NULL, as required.
327        //
328        // The `NOT x` expression has the side effect of implicitly casting `x`
329        // to `bool`. We intentionally do not write `NOT x::bool`, because that
330        // would perform an explicit cast, and to match PostgreSQL we must
331        // perform only an implicit cast.
332        let sum = self.plan_agg(
333            self.scx
334                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
335            expr.negate().cast(self.int32_data_type()),
336            vec![],
337            filter,
338            distinct,
339            over,
340        );
341        sum.equals(Expr::Value(Value::Number(0.to_string())))
342    }
343
344    fn plan_bool_or(
345        &mut self,
346        expr: Expr<Aug>,
347        filter: Option<Box<Expr<Aug>>>,
348        distinct: bool,
349        over: Option<WindowSpec<Aug>>,
350    ) -> Expr<Aug> {
351        // The code below converts `bool_or(x)`z into:
352        //
353        //     sum((x OR false)::int4) > 0
354        //
355        // It is tempting to use `count` instead, but count does not return NULL
356        // when all input values are NULL, as required.
357        //
358        // The `(x OR false)` expression implicitly casts `x` to `bool` without
359        // changing its logical value. It is tempting to use `x::bool` instead,
360        // but that performs an explicit cast, and to match PostgreSQL we must
361        // perform only an implicit cast.
362        let sum = self.plan_agg(
363            self.scx
364                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
365            expr.or(Expr::Value(Value::Boolean(false)))
366                .cast(self.int32_data_type()),
367            vec![],
368            filter,
369            distinct,
370            over,
371        );
372        sum.gt(Expr::Value(Value::Number(0.to_string())))
373    }
374
375    fn rewrite_function(&mut self, func: &Function<Aug>) -> Option<(Ident, Expr<Aug>)> {
376        if let Function {
377            name,
378            args: FunctionArgs::Args { args, order_by: _ },
379            filter,
380            distinct,
381            over,
382        } = func
383        {
384            let pg_catalog_id = self
385                .scx
386                .catalog
387                .resolve_schema(None, PG_CATALOG_SCHEMA)
388                .expect("pg_catalog schema exists")
389                .id();
390            let mz_catalog_id = self
391                .scx
392                .catalog
393                .resolve_schema(None, MZ_CATALOG_SCHEMA)
394                .expect("mz_catalog schema exists")
395                .id();
396            let name = match name {
397                ResolvedItemName::Item {
398                    qualifiers,
399                    full_name,
400                    ..
401                } => {
402                    if ![*pg_catalog_id, *mz_catalog_id].contains(&qualifiers.schema_spec) {
403                        return None;
404                    }
405                    full_name.item.clone()
406                }
407                _ => unreachable!(),
408            };
409
410            let filter = filter.clone();
411            let distinct = *distinct;
412            let over = over.clone();
413            let expr = if args.len() == 1 {
414                let arg = args[0].clone();
415                match name.as_str() {
416                    "avg_internal_v1" => self.plan_avg_internal_v1(arg, filter, distinct, over),
417                    "avg" => self.plan_avg(arg, filter, distinct, over),
418                    "variance" | "var_samp" => {
419                        self.plan_variance(arg, filter, distinct, true, over)
420                    }
421                    "var_pop" => self.plan_variance(arg, filter, distinct, false, over),
422                    "stddev" | "stddev_samp" => self.plan_stddev(arg, filter, distinct, true, over),
423                    "stddev_pop" => self.plan_stddev(arg, filter, distinct, false, over),
424                    "bool_and" => self.plan_bool_and(arg, filter, distinct, over),
425                    "bool_or" => self.plan_bool_or(arg, filter, distinct, over),
426                    _ => return None,
427                }
428            } else if args.len() == 2 {
429                let (lhs, rhs) = (args[0].clone(), args[1].clone());
430                match name.as_str() {
431                    "mod" => lhs.modulo(rhs),
432                    "pow" => Expr::call(
433                        self.scx
434                            .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "power"]),
435                        vec![lhs, rhs],
436                    ),
437                    _ => return None,
438                }
439            } else {
440                return None;
441            };
442            Some((Ident::new_unchecked(name), expr))
443        } else {
444            None
445        }
446    }
447
448    fn rewrite_expr(&mut self, expr: &Expr<Aug>) -> Option<(Ident, Expr<Aug>)> {
449        match expr {
450            Expr::Function(function) => self.rewrite_function(function),
451            // Rewrites special keywords that SQL considers to be function calls
452            // to actual function calls. For example, `SELECT current_timestamp`
453            // is rewritten to `SELECT current_timestamp()`.
454            Expr::Identifier(ident) if ident.len() == 1 => {
455                let ident = normalize::ident(ident[0].clone());
456                let fn_ident = match ident.as_str() {
457                    "current_role" => Some("current_user"),
458                    "current_schema" | "current_timestamp" | "current_user" | "session_user" => {
459                        Some(ident.as_str())
460                    }
461                    _ => None,
462                };
463                match fn_ident {
464                    None => None,
465                    Some(fn_ident) => {
466                        let expr = Expr::call_nullary(
467                            self.scx
468                                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, fn_ident]),
469                        );
470                        Some((Ident::new_unchecked(ident), expr))
471                    }
472                }
473            }
474            _ => None,
475        }
476    }
477}
478
479impl<'ast> VisitMut<'ast, Aug> for FuncRewriter<'_> {
480    fn visit_select_item_mut(&mut self, item: &'ast mut SelectItem<Aug>) {
481        if let SelectItem::Expr { expr, alias: None } = item {
482            visit_mut::visit_expr_mut(self, expr);
483            if let Some((alias, expr)) = self.rewrite_expr(expr) {
484                *item = SelectItem::Expr {
485                    expr,
486                    alias: Some(alias),
487                };
488            }
489        } else {
490            visit_mut::visit_select_item_mut(self, item);
491        }
492    }
493
494    fn visit_table_with_joins_mut(&mut self, item: &'ast mut TableWithJoins<Aug>) {
495        visit_mut::visit_table_with_joins_mut(self, item);
496        match &mut item.relation {
497            TableFactor::Function {
498                function,
499                alias,
500                with_ordinality,
501            } => {
502                self.rewriting_table_factor = true;
503                // Functions that get rewritten must be rewritten as exprs
504                // because their catalog functions cannot be planned.
505                if let Some((ident, expr)) = self.rewrite_function(function) {
506                    let mut select = Select::default().project(SelectItem::Expr {
507                        expr,
508                        alias: Some(match &alias {
509                            Some(TableAlias { name, columns, .. }) => {
510                                columns.get(0).unwrap_or(name).clone()
511                            }
512                            None => ident,
513                        }),
514                    });
515
516                    if *with_ordinality {
517                        select = select.project(SelectItem::Expr {
518                            expr: Expr::Value(Value::Number("1".into())),
519                            alias: Some(ident!("ordinality")),
520                        });
521                    }
522
523                    item.relation = TableFactor::Derived {
524                        lateral: false,
525                        subquery: Box::new(Query {
526                            ctes: mz_sql_parser::ast::CteBlock::Simple(vec![]),
527                            body: mz_sql_parser::ast::SetExpr::Select(Box::new(select)),
528                            order_by: vec![],
529                            limit: None,
530                            offset: None,
531                        }),
532                        alias: alias.clone(),
533                    }
534                }
535                self.rewriting_table_factor = false;
536            }
537            _ => {}
538        }
539    }
540
541    fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
542        visit_mut::visit_expr_mut(self, expr);
543        if let Some((_name, new_expr)) = self.rewrite_expr(expr) {
544            *expr = new_expr;
545        }
546    }
547}
548
549/// Removes syntax sugar to simplify the planner.
550///
551/// For example, `<expr> NOT IN (<subquery>)` is rewritten to `expr <> ALL
552/// (<subquery>)`.
553struct Desugarer<'a> {
554    scx: &'a StatementContext<'a>,
555    status: Result<(), PlanError>,
556    recursion_guard: RecursionGuard,
557}
558
559impl<'a> CheckedRecursion for Desugarer<'a> {
560    fn recursion_guard(&self) -> &RecursionGuard {
561        &self.recursion_guard
562    }
563}
564
565impl<'a, 'ast> VisitMut<'ast, Aug> for Desugarer<'a> {
566    fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
567        self.visit_internal(Self::visit_expr_mut_internal, expr);
568    }
569}
570
571impl<'a> Desugarer<'a> {
572    fn visit_internal<F, X>(&mut self, f: F, x: X)
573    where
574        F: Fn(&mut Self, X) -> Result<(), PlanError>,
575    {
576        if self.status.is_ok() {
577            // self.status could have changed from a deeper call, so don't blindly
578            // overwrite it with the result of this call.
579            let status = self.checked_recur_mut(|d| f(d, x));
580            if self.status.is_ok() {
581                self.status = status;
582            }
583        }
584    }
585
586    fn new(scx: &'a StatementContext) -> Desugarer<'a> {
587        Desugarer {
588            scx,
589            status: Ok(()),
590            recursion_guard: RecursionGuard::with_limit(1024), // chosen arbitrarily
591        }
592    }
593
594    fn visit_expr_mut_internal(&mut self, expr: &mut Expr<Aug>) -> Result<(), PlanError> {
595        // `($expr)` => `$expr`
596        while let Expr::Nested(e) = expr {
597            *expr = e.take();
598        }
599
600        // `$expr BETWEEN $low AND $high` => `$expr >= $low AND $expr <= $low`
601        // `$expr NOT BETWEEN $low AND $high` => `$expr < $low OR $expr > $low`
602        if let Expr::Between {
603            expr: e,
604            low,
605            high,
606            negated,
607        } = expr
608        {
609            if *negated {
610                *expr = Expr::lt(*e.clone(), low.take()).or(e.take().gt(high.take()));
611            } else {
612                *expr = e.clone().gt_eq(low.take()).and(e.take().lt_eq(high.take()));
613            }
614        }
615
616        // When `$expr` is a `ROW` constructor, we need to desugar as described
617        // below in order to enable the row comparision expansion at the end of
618        // this function. We don't do this desugaring unconditionally (i.e.,
619        // when `$expr` is not a `ROW` constructor) because the implementation
620        // in `plan_in_list` is more efficient when row comparison expansion is
621        // not required.
622        //
623        // `$expr IN ($list)` => `$expr = $list[0] OR $expr = $list[1] ... OR $expr = $list[n]`
624        // `$expr NOT IN ($list)` => `$expr <> $list[0] AND $expr <> $list[1] ... AND $expr <> $list[n]`
625        if let Expr::InList {
626            expr: e,
627            list,
628            negated,
629        } = expr
630        {
631            if let Expr::Row { .. } = &**e {
632                if *negated {
633                    *expr = list
634                        .drain(..)
635                        .map(|r| e.clone().not_equals(r))
636                        .reduce(|e1, e2| e1.and(e2))
637                        .expect("list known to contain at least one element");
638                } else {
639                    *expr = list
640                        .drain(..)
641                        .map(|r| e.clone().equals(r))
642                        .reduce(|e1, e2| e1.or(e2))
643                        .expect("list known to contain at least one element");
644                }
645            }
646        }
647
648        // `$expr IN ($subquery)` => `$expr = ANY ($subquery)`
649        // `$expr NOT IN ($subquery)` => `$expr <> ALL ($subquery)`
650        if let Expr::InSubquery {
651            expr: e,
652            subquery,
653            negated,
654        } = expr
655        {
656            if *negated {
657                *expr = Expr::AllSubquery {
658                    left: Box::new(e.take()),
659                    op: Op::bare("<>"),
660                    right: Box::new(subquery.take()),
661                };
662            } else {
663                *expr = Expr::AnySubquery {
664                    left: Box::new(e.take()),
665                    op: Op::bare("="),
666                    right: Box::new(subquery.take()),
667                };
668            }
669        }
670
671        // `$expr = ALL ($array_expr)`
672        // =>
673        // `$expr = ALL (SELECT elem FROM unnest($array_expr) _ (elem))`
674        //
675        // and analogously for other operators and ANY.
676        if let Expr::AnyExpr { left, op, right } | Expr::AllExpr { left, op, right } = expr {
677            let binding = ident!("elem");
678
679            let subquery = Query::select(
680                Select::default()
681                    .from(TableWithJoins {
682                        relation: TableFactor::Function {
683                            function: Function {
684                                name: self
685                                    .scx
686                                    .dangerous_resolve_name(vec![MZ_CATALOG_SCHEMA, "unnest"]),
687                                args: FunctionArgs::args(vec![right.take()]),
688                                filter: None,
689                                over: None,
690                                distinct: false,
691                            },
692                            alias: Some(TableAlias {
693                                name: ident!("_"),
694                                columns: vec![binding.clone()],
695                                strict: true,
696                            }),
697                            with_ordinality: false,
698                        },
699                        joins: vec![],
700                    })
701                    .project(SelectItem::Expr {
702                        expr: Expr::Identifier(vec![binding]),
703                        alias: None,
704                    }),
705            );
706
707            let left = Box::new(left.take());
708
709            let op = op.clone();
710
711            *expr = match expr {
712                Expr::AnyExpr { .. } => Expr::AnySubquery {
713                    left,
714                    op,
715                    right: Box::new(subquery),
716                },
717                Expr::AllExpr { .. } => Expr::AllSubquery {
718                    left,
719                    op,
720                    right: Box::new(subquery),
721                },
722                _ => unreachable!(),
723            };
724        }
725
726        // `$expr = ALL ($subquery)`
727        // =>
728        // `(SELECT mz_unsafe.mz_all($expr = $binding) FROM ($subquery) AS _ ($binding))
729        //
730        // and analogously for other operators and ANY.
731        if let Expr::AnySubquery { left, op, right } | Expr::AllSubquery { left, op, right } = expr
732        {
733            let left = match &mut **left {
734                Expr::Row { .. } => left.take(),
735                _ => Expr::Row {
736                    exprs: vec![left.take()],
737                },
738            };
739
740            let arity = match &left {
741                Expr::Row { exprs } => exprs.len(),
742                _ => unreachable!(),
743            };
744
745            let bindings: Vec<_> = (0..arity)
746                // Note: using unchecked is okay here because we know the value will be less than
747                // our maximum length.
748                .map(|_| Ident::new_unchecked(format!("right_{}", Uuid::new_v4())))
749                .collect();
750
751            let select = Select::default()
752                .from(TableWithJoins::subquery(
753                    right.take(),
754                    TableAlias {
755                        name: ident!("subquery"),
756                        columns: bindings.clone(),
757                        strict: true,
758                    },
759                ))
760                .project(SelectItem::Expr {
761                    expr: left
762                        .binop(
763                            op.clone(),
764                            Expr::Row {
765                                exprs: bindings
766                                    .into_iter()
767                                    .map(|b| Expr::Identifier(vec![b]))
768                                    .collect(),
769                            },
770                        )
771                        .call_unary(self.scx.dangerous_resolve_name(match expr {
772                            Expr::AnySubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_any"],
773                            Expr::AllSubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_all"],
774                            _ => unreachable!(),
775                        })),
776                    alias: None,
777                });
778
779            *expr = Expr::Subquery(Box::new(Query::select(select)));
780        }
781
782        // Expands row comparisons.
783        //
784        // ROW($l1, $l2, ..., $ln) = ROW($r1, $r2, ..., $rn)
785        // =>
786        // $l1 = $r1 AND $l2 = $r2 AND ... AND $ln = $rn
787        //
788        // ROW($l1, $l2, ..., $ln) < ROW($r1, $r2, ..., $rn)
789        // =>
790        // $l1 < $r1 OR ($l1 = $r1 AND ($l2 < $r2 OR ($l2 = $r2 AND ... ($ln < $rn))))
791        //
792        // ROW($l1, $l2, ..., $ln) <= ROW($r1, $r2, ..., $rn)
793        // =>
794        // $l1 < $r1 OR ($l1 = $r1 AND ($l2 < $r2 OR ($l2 = $r2 AND ... ($ln <= $rn))))
795        //
796        // and analogously for the inverse operations !=, >, and >=.
797        if let Expr::Op {
798            op,
799            expr1: left,
800            expr2: Some(right),
801        } = expr
802        {
803            if let (Expr::Row { exprs: left }, Expr::Row { exprs: right }) =
804                (&mut **left, &mut **right)
805            {
806                if matches!(normalize::op(op)?, "=" | "<>" | "<" | "<=" | ">" | ">=") {
807                    if left.len() != right.len() {
808                        sql_bail!("unequal number of entries in row expressions");
809                    }
810                    if left.is_empty() {
811                        assert!(right.is_empty());
812                        sql_bail!("cannot compare rows of zero length");
813                    }
814                }
815                match normalize::op(op)? {
816                    "=" | "<>" => {
817                        let mut pairs = left.iter_mut().zip(right);
818                        let mut new = pairs
819                            .next()
820                            .map(|(l, r)| l.take().equals(r.take()))
821                            .expect("cannot compare rows of zero length");
822                        for (l, r) in pairs {
823                            new = l.take().equals(r.take()).and(new);
824                        }
825                        if normalize::op(op)? == "<>" {
826                            new = new.negate();
827                        }
828                        *expr = new;
829                    }
830                    "<" | "<=" | ">" | ">=" => {
831                        let strict_op = match normalize::op(op)? {
832                            "<" | "<=" => "<",
833                            ">" | ">=" => ">",
834                            _ => unreachable!(),
835                        };
836                        let (l, r) = (left.last_mut().unwrap(), right.last_mut().unwrap());
837                        let mut new = l.take().binop(op.clone(), r.take());
838                        for (l, r) in left.iter_mut().zip(right).rev().skip(1) {
839                            new = l
840                                .clone()
841                                .binop(Op::bare(strict_op), r.clone())
842                                .or(l.take().equals(r.take()).and(new));
843                        }
844                        *expr = new;
845                    }
846                    _ if left.len() == 1 && right.len() == 1 => {
847                        let left = left.remove(0);
848                        let right = right.remove(0);
849                        *expr = left.binop(op.clone(), right);
850                    }
851                    _ => (),
852                }
853            }
854        }
855
856        visit_mut::visit_expr_mut(self, expr);
857        Ok(())
858    }
859}