mz_sql/plan/
transform_ast.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL ASTs.
11//!
12//! Most query optimizations are performed by the dataflow layer, but some
13//! are much easier to perform in SQL. Someday, we'll want our own SQL IR,
14//! but for now we just use the parser's AST directly.
15
16use itertools::Itertools;
17use mz_ore::id_gen::IdGen;
18use mz_ore::stack::{CheckedRecursion, RecursionGuard};
19use mz_repr::namespaces::{MZ_CATALOG_SCHEMA, MZ_UNSAFE_SCHEMA, PG_CATALOG_SCHEMA};
20use mz_sql_parser::ast::visit_mut::{self, VisitMut, VisitMutNode};
21use mz_sql_parser::ast::{
22    Expr, Function, FunctionArgs, HomogenizingFunction, Ident, IsExprConstruct, Op, OrderByExpr,
23    Query, Select, SelectItem, TableAlias, TableFactor, TableWithJoins, Value, WindowSpec,
24};
25use mz_sql_parser::ident;
26
27use crate::names::{Aug, PartialItemName, ResolvedDataType, ResolvedItemName};
28use crate::plan::{PlanError, StatementContext};
29use crate::{ORDINALITY_COL_NAME, normalize};
30
31pub(crate) fn transform<N>(scx: &StatementContext, node: &mut N) -> Result<(), PlanError>
32where
33    N: for<'a> VisitMutNode<'a, Aug>,
34{
35    let mut func_rewriter = FuncRewriter::new(scx);
36    node.visit_mut(&mut func_rewriter);
37    func_rewriter.status?;
38
39    let mut desugarer = Desugarer::new(scx);
40    node.visit_mut(&mut desugarer);
41    desugarer.status
42}
43
/// Transforms various functions to forms that are more easily handled by the
/// planner.
///
/// Specifically:
///
///   * Rewrites the `mod` function to the `%` binary operator, so the modulus
///     code only needs to handle the operator form.
///
///   * Rewrites the `pow` function to a call to `pg_catalog.power`.
///
///   * Rewrites `avg(col)` to `sum(col) / count(col)`, so that we can pretend
///     the `avg` aggregate function doesn't exist from here on out. This also
///     has the nice side effect of reusing the division planning logic, which
///     is not trivial for some types, like decimals.
///
///   * Rewrites the suite of standard deviation and variance functions in a
///     manner similar to `avg`.
///
///   * Rewrites `bool_and` and `bool_or` to comparisons over a `sum` of the
///     argument cast to `int4`.
///
///   * Rewrites bare keyword identifiers such as `current_timestamp` to
///     nullary function calls.
///
/// NOTE(review): this comment previously also listed a rewrite of `nullif` to
/// a `CASE` expression, but `rewrite_function` below has no `nullif` arm.
/// Confirm where (or whether) that rewrite happens before relying on it.
//
// TODO(sploiselle): rewrite these in terms of func::sql_op!
struct FuncRewriter<'a> {
    /// Statement context used to resolve the catalog names and types that the
    /// rewritten expressions refer to.
    scx: &'a StatementContext<'a>,
    /// The first error recorded during the traversal, if any. `transform`
    /// inspects this once the visit has finished.
    status: Result<(), PlanError>,
    /// Set while a function in `FROM` position (`TableFactor::Function`) is
    /// being rewritten; `plan_agg` records an error if it runs while this is
    /// set.
    rewriting_table_factor: bool,
}
69
70impl<'a> FuncRewriter<'a> {
71    fn new(scx: &'a StatementContext<'a>) -> FuncRewriter<'a> {
72        FuncRewriter {
73            scx,
74            status: Ok(()),
75            rewriting_table_factor: false,
76        }
77    }
78
79    fn resolve_known_valid_data_type(&self, name: &PartialItemName) -> ResolvedDataType {
80        let item = self
81            .scx
82            .catalog
83            .resolve_type(name)
84            .expect("data type known to be valid");
85        let full_name = self.scx.catalog.resolve_full_name(item.name());
86        ResolvedDataType::Named {
87            id: item.id(),
88            qualifiers: item.name().qualifiers.clone(),
89            full_name,
90            modifiers: vec![],
91            print_id: true,
92        }
93    }
94
95    fn int32_data_type(&self) -> ResolvedDataType {
96        self.resolve_known_valid_data_type(&PartialItemName {
97            database: None,
98            schema: Some(PG_CATALOG_SCHEMA.into()),
99            item: "int4".into(),
100        })
101    }
102
103    // Divides `lhs` by `rhs` but replaces division-by-zero errors with NULL;
104    // note that this is semantically equivalent to `NULLIF(rhs, 0)`.
105    fn plan_divide(lhs: Expr<Aug>, rhs: Expr<Aug>) -> Expr<Aug> {
106        lhs.divide(Expr::Case {
107            operand: None,
108            conditions: vec![rhs.clone().equals(Expr::number("0"))],
109            results: vec![Expr::null()],
110            else_result: Some(Box::new(rhs)),
111        })
112    }
113
114    fn plan_agg(
115        &mut self,
116        name: ResolvedItemName,
117        expr: Expr<Aug>,
118        order_by: Vec<OrderByExpr<Aug>>,
119        filter: Option<Box<Expr<Aug>>>,
120        distinct: bool,
121        over: Option<WindowSpec<Aug>>,
122    ) -> Expr<Aug> {
123        if self.rewriting_table_factor && self.status.is_ok() {
124            self.status = Err(PlanError::Unstructured(
125                "aggregate functions are not supported in functions in FROM".to_string(),
126            ))
127        }
128        Expr::Function(Function {
129            name,
130            args: FunctionArgs::Args {
131                args: vec![expr],
132                order_by,
133            },
134            filter,
135            over,
136            distinct,
137        })
138    }
139
140    fn plan_avg(
141        &mut self,
142        expr: Expr<Aug>,
143        filter: Option<Box<Expr<Aug>>>,
144        distinct: bool,
145        over: Option<WindowSpec<Aug>>,
146    ) -> Expr<Aug> {
147        let sum = self
148            .plan_agg(
149                self.scx
150                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
151                expr.clone(),
152                vec![],
153                filter.clone(),
154                distinct,
155                over.clone(),
156            )
157            .call_unary(
158                self.scx
159                    .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
160            );
161        let count = self.plan_agg(
162            self.scx
163                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
164            expr,
165            vec![],
166            filter,
167            distinct,
168            over,
169        );
170        Self::plan_divide(sum, count)
171    }
172
173    /// Same as `plan_avg` but internally uses `mz_avg_promotion_internal_v1`.
174    fn plan_avg_internal_v1(
175        &mut self,
176        expr: Expr<Aug>,
177        filter: Option<Box<Expr<Aug>>>,
178        distinct: bool,
179        over: Option<WindowSpec<Aug>>,
180    ) -> Expr<Aug> {
181        let sum = self
182            .plan_agg(
183                self.scx
184                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
185                expr.clone(),
186                vec![],
187                filter.clone(),
188                distinct,
189                over.clone(),
190            )
191            .call_unary(
192                self.scx
193                    .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion_internal_v1"]),
194            );
195        let count = self.plan_agg(
196            self.scx
197                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
198            expr,
199            vec![],
200            filter,
201            distinct,
202            over,
203        );
204        Self::plan_divide(sum, count)
205    }
206
207    fn plan_variance(
208        &mut self,
209        expr: Expr<Aug>,
210        filter: Option<Box<Expr<Aug>>>,
211        distinct: bool,
212        sample: bool,
213        over: Option<WindowSpec<Aug>>,
214    ) -> Expr<Aug> {
215        // N.B. this variance calculation uses the "textbook" algorithm, which
216        // is known to accumulate problematic amounts of error. The numerically
217        // stable variants, the most well-known of which is Welford's, are
218        // however difficult to implement inside of Differential Dataflow, as
219        // they do not obviously support retractions efficiently (database-issues#436).
220        //
221        // The code below converts var_samp(x) into
222        //
223        //     (sum(x²) - sum(x)² / count(x)) / (count(x) - 1)
224        //
225        // and var_pop(x) into:
226        //
227        //     (sum(x²) - sum(x)² / count(x)) / count(x)
228        //
229        let expr = expr.call_unary(
230            self.scx
231                .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
232        );
233        let expr_squared = expr.clone().multiply(expr.clone());
234        let sum_squares = self.plan_agg(
235            self.scx
236                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
237            expr_squared,
238            vec![],
239            filter.clone(),
240            distinct,
241            over.clone(),
242        );
243        let sum = self.plan_agg(
244            self.scx
245                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
246            expr.clone(),
247            vec![],
248            filter.clone(),
249            distinct,
250            over.clone(),
251        );
252        let sum_squared = sum.clone().multiply(sum);
253        let count = self.plan_agg(
254            self.scx
255                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
256            expr,
257            vec![],
258            filter,
259            distinct,
260            over,
261        );
262        let result = Self::plan_divide(
263            sum_squares.minus(Self::plan_divide(sum_squared, count.clone())),
264            if sample {
265                count.minus(Expr::number("1"))
266            } else {
267                count
268            },
269        );
270        // Result is _basically_ what we want, except
271        // that due to numerical inaccuracy, it might be a negative
272        // number very close to zero when it should mathematically be zero.
273        // This makes it so `stddev` fails as it tries to take the square root
274        // of a negative number.
275        // So, we need the following logic:
276        // If `result` is NULL, return NULL (no surprise here)
277        // Otherwise, if `result` is >0, return `result` (no surprise here either)
278        // Otherwise, return 0.
279        //
280        // Unfortunately, we can't use `GREATEST` directly for this,
281        // since `greatest(NULL, 0)` is 0, not NULL, so we need to
282        // create a `Case` expression that computes `result`
283        // twice. Hopefully the optimizer can deal with this!
284        let result_is_null = Expr::IsExpr {
285            expr: Box::new(result.clone()),
286            construct: IsExprConstruct::Null,
287            negated: false,
288        };
289        Expr::Case {
290            operand: None,
291            conditions: vec![result_is_null],
292            results: vec![Expr::Value(Value::Null)],
293            else_result: Some(Box::new(Expr::HomogenizingFunction {
294                function: HomogenizingFunction::Greatest,
295                exprs: vec![result, Expr::number("0")],
296            })),
297        }
298    }
299
300    fn plan_stddev(
301        &mut self,
302        expr: Expr<Aug>,
303        filter: Option<Box<Expr<Aug>>>,
304        distinct: bool,
305        sample: bool,
306        over: Option<WindowSpec<Aug>>,
307    ) -> Expr<Aug> {
308        self.plan_variance(expr, filter, distinct, sample, over)
309            .call_unary(
310                self.scx
311                    .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sqrt"]),
312            )
313    }
314
315    fn plan_bool_and(
316        &mut self,
317        expr: Expr<Aug>,
318        filter: Option<Box<Expr<Aug>>>,
319        distinct: bool,
320        over: Option<WindowSpec<Aug>>,
321    ) -> Expr<Aug> {
322        // The code below converts `bool_and(x)` into:
323        //
324        //     sum((NOT x)::int4) = 0
325        //
326        // It is tempting to use `count` instead, but count does not return NULL
327        // when all input values are NULL, as required.
328        //
329        // The `NOT x` expression has the side effect of implicitly casting `x`
330        // to `bool`. We intentionally do not write `NOT x::bool`, because that
331        // would perform an explicit cast, and to match PostgreSQL we must
332        // perform only an implicit cast.
333        let sum = self.plan_agg(
334            self.scx
335                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
336            expr.negate().cast(self.int32_data_type()),
337            vec![],
338            filter,
339            distinct,
340            over,
341        );
342        sum.equals(Expr::Value(Value::Number(0.to_string())))
343    }
344
345    fn plan_bool_or(
346        &mut self,
347        expr: Expr<Aug>,
348        filter: Option<Box<Expr<Aug>>>,
349        distinct: bool,
350        over: Option<WindowSpec<Aug>>,
351    ) -> Expr<Aug> {
352        // The code below converts `bool_or(x)`z into:
353        //
354        //     sum((x OR false)::int4) > 0
355        //
356        // It is tempting to use `count` instead, but count does not return NULL
357        // when all input values are NULL, as required.
358        //
359        // The `(x OR false)` expression implicitly casts `x` to `bool` without
360        // changing its logical value. It is tempting to use `x::bool` instead,
361        // but that performs an explicit cast, and to match PostgreSQL we must
362        // perform only an implicit cast.
363        let sum = self.plan_agg(
364            self.scx
365                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
366            expr.or(Expr::Value(Value::Boolean(false)))
367                .cast(self.int32_data_type()),
368            vec![],
369            filter,
370            distinct,
371            over,
372        );
373        sum.gt(Expr::Value(Value::Number(0.to_string())))
374    }
375
376    fn rewrite_function(&mut self, func: &Function<Aug>) -> Option<(Ident, Expr<Aug>)> {
377        if let Function {
378            name,
379            args: FunctionArgs::Args { args, order_by: _ },
380            filter,
381            distinct,
382            over,
383        } = func
384        {
385            let pg_catalog_id = self
386                .scx
387                .catalog
388                .resolve_schema(None, PG_CATALOG_SCHEMA)
389                .expect("pg_catalog schema exists")
390                .id();
391            let mz_catalog_id = self
392                .scx
393                .catalog
394                .resolve_schema(None, MZ_CATALOG_SCHEMA)
395                .expect("mz_catalog schema exists")
396                .id();
397            let name = match name {
398                ResolvedItemName::Item {
399                    qualifiers,
400                    full_name,
401                    ..
402                } => {
403                    if ![*pg_catalog_id, *mz_catalog_id].contains(&qualifiers.schema_spec) {
404                        return None;
405                    }
406                    full_name.item.clone()
407                }
408                _ => unreachable!(),
409            };
410
411            let filter = filter.clone();
412            let distinct = *distinct;
413            let over = over.clone();
414            let expr = if args.len() == 1 {
415                let arg = args[0].clone();
416                match name.as_str() {
417                    "avg_internal_v1" => self.plan_avg_internal_v1(arg, filter, distinct, over),
418                    "avg" => self.plan_avg(arg, filter, distinct, over),
419                    "variance" | "var_samp" => {
420                        self.plan_variance(arg, filter, distinct, true, over)
421                    }
422                    "var_pop" => self.plan_variance(arg, filter, distinct, false, over),
423                    "stddev" | "stddev_samp" => self.plan_stddev(arg, filter, distinct, true, over),
424                    "stddev_pop" => self.plan_stddev(arg, filter, distinct, false, over),
425                    "bool_and" => self.plan_bool_and(arg, filter, distinct, over),
426                    "bool_or" => self.plan_bool_or(arg, filter, distinct, over),
427                    _ => return None,
428                }
429            } else if args.len() == 2 {
430                let (lhs, rhs) = (args[0].clone(), args[1].clone());
431                match name.as_str() {
432                    "mod" => lhs.modulo(rhs),
433                    "pow" => Expr::call(
434                        self.scx
435                            .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "power"]),
436                        vec![lhs, rhs],
437                    ),
438                    _ => return None,
439                }
440            } else {
441                return None;
442            };
443            Some((Ident::new_unchecked(name), expr))
444        } else {
445            None
446        }
447    }
448
449    fn rewrite_expr(&mut self, expr: &Expr<Aug>) -> Option<(Ident, Expr<Aug>)> {
450        match expr {
451            Expr::Function(function) => self.rewrite_function(function),
452            // Rewrites special keywords that SQL considers to be function calls
453            // to actual function calls. For example, `SELECT current_timestamp`
454            // is rewritten to `SELECT current_timestamp()`.
455            Expr::Identifier(ident) if ident.len() == 1 => {
456                let ident = normalize::ident(ident[0].clone());
457                let fn_ident = match ident.as_str() {
458                    "current_role" => Some("current_user"),
459                    "current_schema" | "current_timestamp" | "current_user" | "session_user"
460                    | "current_catalog" => Some(ident.as_str()),
461                    _ => None,
462                };
463                match fn_ident {
464                    None => None,
465                    Some(fn_ident) => {
466                        let expr = Expr::call_nullary(
467                            self.scx
468                                .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, fn_ident]),
469                        );
470                        Some((Ident::new_unchecked(ident), expr))
471                    }
472                }
473            }
474            _ => None,
475        }
476    }
477}
478
479impl<'ast> VisitMut<'ast, Aug> for FuncRewriter<'_> {
480    fn visit_select_item_mut(&mut self, item: &'ast mut SelectItem<Aug>) {
481        if let SelectItem::Expr { expr, alias: None } = item {
482            visit_mut::visit_expr_mut(self, expr);
483            if let Some((alias, expr)) = self.rewrite_expr(expr) {
484                *item = SelectItem::Expr {
485                    expr,
486                    alias: Some(alias),
487                };
488            }
489        } else {
490            visit_mut::visit_select_item_mut(self, item);
491        }
492    }
493
494    fn visit_table_with_joins_mut(&mut self, item: &'ast mut TableWithJoins<Aug>) {
495        visit_mut::visit_table_with_joins_mut(self, item);
496        match &mut item.relation {
497            TableFactor::Function {
498                function,
499                alias,
500                with_ordinality,
501            } => {
502                self.rewriting_table_factor = true;
503                // Functions that get rewritten must be rewritten as exprs
504                // because their catalog functions cannot be planned.
505                if let Some((ident, expr)) = self.rewrite_function(function) {
506                    let mut select = Select::default().project(SelectItem::Expr {
507                        expr,
508                        alias: Some(match &alias {
509                            Some(TableAlias { name, columns, .. }) => {
510                                columns.get(0).unwrap_or(name).clone()
511                            }
512                            None => ident,
513                        }),
514                    });
515
516                    if *with_ordinality {
517                        select = select.project(SelectItem::Expr {
518                            expr: Expr::Value(Value::Number("1".into())),
519                            alias: Some(ident!(ORDINALITY_COL_NAME)),
520                        });
521                    }
522
523                    item.relation = TableFactor::Derived {
524                        lateral: false,
525                        subquery: Box::new(Query {
526                            ctes: mz_sql_parser::ast::CteBlock::Simple(vec![]),
527                            body: mz_sql_parser::ast::SetExpr::Select(Box::new(select)),
528                            order_by: vec![],
529                            limit: None,
530                            offset: None,
531                        }),
532                        alias: alias.clone(),
533                    }
534                }
535                self.rewriting_table_factor = false;
536            }
537            _ => {}
538        }
539    }
540
541    fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
542        visit_mut::visit_expr_mut(self, expr);
543        if let Some((_name, new_expr)) = self.rewrite_expr(expr) {
544            *expr = new_expr;
545        }
546    }
547}
548
/// Removes syntax sugar to simplify the planner.
///
/// For example, `<expr> NOT IN (<subquery>)` is rewritten to `<expr> <> ALL
/// (<subquery>)`.
struct Desugarer<'a> {
    /// Statement context used to resolve the names of the functions that the
    /// desugared expressions call.
    scx: &'a StatementContext<'a>,
    /// Outcome of the traversal so far. Once this holds an error,
    /// `visit_internal` stops desugaring further expressions.
    status: Result<(), PlanError>,
    /// Source of the unique ids embedded in generated column and subquery
    /// names.
    id_gen: IdGen,
    /// Guard exposed through the `CheckedRecursion` impl; `visit_internal`
    /// runs every expression visit through `checked_recur_mut`.
    recursion_guard: RecursionGuard,
}
559
560impl<'a> CheckedRecursion for Desugarer<'a> {
561    fn recursion_guard(&self) -> &RecursionGuard {
562        &self.recursion_guard
563    }
564}
565
566impl<'a, 'ast> VisitMut<'ast, Aug> for Desugarer<'a> {
567    fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
568        self.visit_internal(Self::visit_expr_mut_internal, expr);
569    }
570}
571
572impl<'a> Desugarer<'a> {
573    fn visit_internal<F, X>(&mut self, f: F, x: X)
574    where
575        F: Fn(&mut Self, X) -> Result<(), PlanError>,
576    {
577        if self.status.is_ok() {
578            // self.status could have changed from a deeper call, so don't blindly
579            // overwrite it with the result of this call.
580            let status = self.checked_recur_mut(|d| f(d, x));
581            if self.status.is_ok() {
582                self.status = status;
583            }
584        }
585    }
586
587    fn new(scx: &'a StatementContext) -> Desugarer<'a> {
588        Desugarer {
589            scx,
590            status: Ok(()),
591            id_gen: Default::default(),
592            recursion_guard: RecursionGuard::with_limit(1024), // chosen arbitrarily
593        }
594    }
595
    /// Desugars `expr` in place and then recurses into its children.
    ///
    /// The rewrites below are applied to the node itself, in order, before
    /// the default traversal (`visit_mut::visit_expr_mut`) descends into
    /// whatever expression is left. Later steps rely on the output of earlier
    /// ones: for example, `IN (<subquery>)` is first turned into
    /// `= ANY (<subquery>)`, which the `ANY`/`ALL` step then turns into a
    /// scalar subquery.
    ///
    /// # Errors
    ///
    /// Returns an error if `normalize::op` rejects the operator of a row
    /// comparison, or (via `sql_bail!`) if the two rows of a row comparison
    /// have different lengths or are empty.
    fn visit_expr_mut_internal(&mut self, expr: &mut Expr<Aug>) -> Result<(), PlanError> {
        // `($expr)` => `$expr`
        while let Expr::Nested(e) = expr {
            *expr = e.take();
        }

        // `$expr BETWEEN $low AND $high` => `$expr >= $low AND $expr <= $high`
        // `$expr NOT BETWEEN $low AND $high` => `$expr < $low OR $expr > $high`
        if let Expr::Between {
            expr: e,
            low,
            high,
            negated,
        } = expr
        {
            if *negated {
                *expr = Expr::lt(*e.clone(), low.take()).or(e.take().gt(high.take()));
            } else {
                *expr = e.clone().gt_eq(low.take()).and(e.take().lt_eq(high.take()));
            }
        }

        // When `$expr` is a `ROW` constructor, we need to desugar as described
        // below in order to enable the row comparison expansion at the end of
        // this function. We don't do this desugaring unconditionally (i.e.,
        // when `$expr` is not a `ROW` constructor) because the implementation
        // in `plan_in_list` is more efficient when row comparison expansion is
        // not required.
        //
        // `$expr IN ($list)` => `$expr = $list[0] OR $expr = $list[1] ... OR $expr = $list[n]`
        // `$expr NOT IN ($list)` => `$expr <> $list[0] AND $expr <> $list[1] ... AND $expr <> $list[n]`
        if let Expr::InList {
            expr: e,
            list,
            negated,
        } = expr
        {
            if let Expr::Row { .. } = &**e {
                if *negated {
                    *expr = list
                        .drain(..)
                        .map(|r| e.clone().not_equals(r))
                        .reduce(|e1, e2| e1.and(e2))
                        .expect("list known to contain at least one element");
                } else {
                    *expr = list
                        .drain(..)
                        .map(|r| e.clone().equals(r))
                        .reduce(|e1, e2| e1.or(e2))
                        .expect("list known to contain at least one element");
                }
            }
        }

        // `$expr IN ($subquery)` => `$expr = ANY ($subquery)`
        // `$expr NOT IN ($subquery)` => `$expr <> ALL ($subquery)`
        if let Expr::InSubquery {
            expr: e,
            subquery,
            negated,
        } = expr
        {
            if *negated {
                *expr = Expr::AllSubquery {
                    left: Box::new(e.take()),
                    op: Op::bare("<>"),
                    right: Box::new(subquery.take()),
                };
            } else {
                *expr = Expr::AnySubquery {
                    left: Box::new(e.take()),
                    op: Op::bare("="),
                    right: Box::new(subquery.take()),
                };
            }
        }

        // `$expr = ALL ($array_expr)`
        // =>
        // `$expr = ALL (SELECT elem FROM unnest($array_expr) _ (elem))`
        //
        // and analogously for other operators and ANY.
        if let Expr::AnyExpr { left, op, right } | Expr::AllExpr { left, op, right } = expr {
            let binding = ident!("elem");

            let subquery = Query::select(
                Select::default()
                    .from(TableWithJoins {
                        relation: TableFactor::Function {
                            function: Function {
                                name: self
                                    .scx
                                    .dangerous_resolve_name(vec![MZ_CATALOG_SCHEMA, "unnest"]),
                                args: FunctionArgs::args(vec![right.take()]),
                                filter: None,
                                over: None,
                                distinct: false,
                            },
                            alias: Some(TableAlias {
                                name: ident!("_"),
                                columns: vec![binding.clone()],
                                strict: true,
                            }),
                            with_ordinality: false,
                        },
                        joins: vec![],
                    })
                    .project(SelectItem::Expr {
                        expr: Expr::Identifier(vec![binding]),
                        alias: None,
                    }),
            );

            let left = Box::new(left.take());

            let op = op.clone();

            *expr = match expr {
                Expr::AnyExpr { .. } => Expr::AnySubquery {
                    left,
                    op,
                    right: Box::new(subquery),
                },
                Expr::AllExpr { .. } => Expr::AllSubquery {
                    left,
                    op,
                    right: Box::new(subquery),
                },
                _ => unreachable!(),
            };
        }

        // `$expr = ALL ($subquery)`
        // =>
        // `(SELECT mz_unsafe.mz_all($expr = $binding) FROM ($subquery) AS _ ($binding))`
        //
        // and analogously for other operators and ANY.
        if let Expr::AnySubquery { left, op, right } | Expr::AllSubquery { left, op, right } = expr
        {
            // Wrap a non-row left-hand side in a single-element `ROW`, so that
            // the comparison built below is always between two rows.
            let left = match &mut **left {
                Expr::Row { .. } => left.take(),
                _ => Expr::Row {
                    exprs: vec![left.take()],
                },
            };

            let arity = match &left {
                Expr::Row { exprs } => exprs.len(),
                _ => unreachable!(),
            };

            // One uniquely named output column per element of the left row.
            let bindings: Vec<_> = (0..arity)
                // Note: using unchecked is okay here because we know the value will be less than
                // our maximum length.
                .map(|col| {
                    let unique_id = self.id_gen.allocate_id();
                    Ident::new_unchecked(format!("right_col{col}_{unique_id}"))
                })
                .collect();

            let subquery_unique_id = self.id_gen.allocate_id();
            // Note: okay to use unchecked here because we know the value will be small enough.
            let subquery_name = Ident::new_unchecked(format!("subquery{subquery_unique_id}"));
            let select = Select::default()
                .from(TableWithJoins::subquery(
                    right.take(),
                    TableAlias {
                        name: subquery_name,
                        columns: bindings.clone(),
                        strict: true,
                    },
                ))
                .project(SelectItem::Expr {
                    expr: left
                        .binop(
                            op.clone(),
                            Expr::Row {
                                exprs: bindings
                                    .into_iter()
                                    .map(|b| Expr::Identifier(vec![b]))
                                    .collect(),
                            },
                        )
                        .call_unary(self.scx.dangerous_resolve_name(match expr {
                            Expr::AnySubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_any"],
                            Expr::AllSubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_all"],
                            _ => unreachable!(),
                        })),
                    alias: None,
                });

            *expr = Expr::Subquery(Box::new(Query::select(select)));
        }

        // Expands row comparisons.
        //
        // ROW($l1, $l2, ..., $ln) = ROW($r1, $r2, ..., $rn)
        // =>
        // $l1 = $r1 AND $l2 = $r2 AND ... AND $ln = $rn
        //
        // ROW($l1, $l2, ..., $ln) < ROW($r1, $r2, ..., $rn)
        // =>
        // $l1 < $r1 OR ($l1 = $r1 AND ($l2 < $r2 OR ($l2 = $r2 AND ... ($ln < $rn))))
        //
        // ROW($l1, $l2, ..., $ln) <= ROW($r1, $r2, ..., $rn)
        // =>
        // $l1 < $r1 OR ($l1 = $r1 AND ($l2 < $r2 OR ($l2 = $r2 AND ... ($ln <= $rn))))
        //
        // and analogously for the inverse operations !=, >, and >=.
        if let Expr::Op {
            op,
            expr1: left,
            expr2: Some(right),
        } = expr
        {
            if let (Expr::Row { exprs: left }, Expr::Row { exprs: right }) =
                (&mut **left, &mut **right)
            {
                // Arity is only validated for the comparison operators that
                // are expanded below.
                if matches!(normalize::op(op)?, "=" | "<>" | "<" | "<=" | ">" | ">=") {
                    if left.len() != right.len() {
                        sql_bail!("unequal number of entries in row expressions");
                    }
                    if left.is_empty() {
                        assert!(right.is_empty());
                        sql_bail!("cannot compare rows of zero length");
                    }
                }
                match normalize::op(op)? {
                    "=" | "<>" => {
                        let mut pairs = left.iter_mut().zip_eq(right);
                        let mut new = pairs
                            .next()
                            .map(|(l, r)| l.take().equals(r.take()))
                            .expect("cannot compare rows of zero length");
                        for (l, r) in pairs {
                            new = l.take().equals(r.take()).and(new);
                        }
                        // `<>` is the negation of the element-wise equality.
                        if normalize::op(op)? == "<>" {
                            new = new.negate();
                        }
                        *expr = new;
                    }
                    "<" | "<=" | ">" | ">=" => {
                        let strict_op = match normalize::op(op)? {
                            "<" | "<=" => "<",
                            ">" | ">=" => ">",
                            _ => unreachable!(),
                        };
                        // The last pair is compared with the original operator;
                        // the expression is then built outwards from it, with
                        // every earlier pair compared strictly.
                        let (l, r) = (left.last_mut().unwrap(), right.last_mut().unwrap());
                        let mut new = l.take().binop(op.clone(), r.take());
                        for (l, r) in left
                            .iter_mut()
                            .rev()
                            .zip_eq(right.into_iter().rev())
                            .skip(1)
                        {
                            new = l
                                .clone()
                                .binop(Op::bare(strict_op), r.clone())
                                .or(l.take().equals(r.take()).and(new));
                        }
                        *expr = new;
                    }
                    // Any other operator applied to two single-element rows is
                    // applied to the elements directly.
                    _ if left.len() == 1 && right.len() == 1 => {
                        let left = left.remove(0);
                        let right = right.remove(0);
                        *expr = left.binop(op.clone(), right);
                    }
                    _ => (),
                }
            }
        }

        visit_mut::visit_expr_mut(self, expr);
        Ok(())
    }
872}