1use itertools::Itertools;
17use mz_ore::id_gen::IdGen;
18use mz_ore::stack::{CheckedRecursion, RecursionGuard};
19use mz_repr::namespaces::{MZ_CATALOG_SCHEMA, MZ_UNSAFE_SCHEMA, PG_CATALOG_SCHEMA};
20use mz_sql_parser::ast::visit_mut::{self, VisitMut, VisitMutNode};
21use mz_sql_parser::ast::{
22 Expr, Function, FunctionArgs, HomogenizingFunction, Ident, IsExprConstruct, Op, OrderByExpr,
23 Query, Select, SelectItem, TableAlias, TableFactor, TableWithJoins, Value, WindowSpec,
24};
25use mz_sql_parser::ident;
26
27use crate::names::{Aug, PartialItemName, ResolvedDataType, ResolvedItemName};
28use crate::plan::{PlanError, StatementContext};
29use crate::{ORDINALITY_COL_NAME, normalize};
30
31pub(crate) fn transform<N>(scx: &StatementContext, node: &mut N) -> Result<(), PlanError>
32where
33 N: for<'a> VisitMutNode<'a, Aug>,
34{
35 let mut func_rewriter = FuncRewriter::new(scx);
36 node.visit_mut(&mut func_rewriter);
37 func_rewriter.status?;
38
39 let mut desugarer = Desugarer::new(scx);
40 node.visit_mut(&mut desugarer);
41 desugarer.status
42}
43
44struct FuncRewriter<'a> {
65 scx: &'a StatementContext<'a>,
66 status: Result<(), PlanError>,
67 rewriting_table_factor: bool,
68}
69
70impl<'a> FuncRewriter<'a> {
71 fn new(scx: &'a StatementContext<'a>) -> FuncRewriter<'a> {
72 FuncRewriter {
73 scx,
74 status: Ok(()),
75 rewriting_table_factor: false,
76 }
77 }
78
79 fn resolve_known_valid_data_type(&self, name: &PartialItemName) -> ResolvedDataType {
80 let item = self
81 .scx
82 .catalog
83 .resolve_type(name)
84 .expect("data type known to be valid");
85 let full_name = self.scx.catalog.resolve_full_name(item.name());
86 ResolvedDataType::Named {
87 id: item.id(),
88 qualifiers: item.name().qualifiers.clone(),
89 full_name,
90 modifiers: vec![],
91 print_id: true,
92 }
93 }
94
95 fn int32_data_type(&self) -> ResolvedDataType {
96 self.resolve_known_valid_data_type(&PartialItemName {
97 database: None,
98 schema: Some(PG_CATALOG_SCHEMA.into()),
99 item: "int4".into(),
100 })
101 }
102
103 fn plan_divide(lhs: Expr<Aug>, rhs: Expr<Aug>) -> Expr<Aug> {
106 lhs.divide(Expr::Case {
107 operand: None,
108 conditions: vec![rhs.clone().equals(Expr::number("0"))],
109 results: vec![Expr::null()],
110 else_result: Some(Box::new(rhs)),
111 })
112 }
113
114 fn plan_agg(
115 &mut self,
116 name: ResolvedItemName,
117 expr: Expr<Aug>,
118 order_by: Vec<OrderByExpr<Aug>>,
119 filter: Option<Box<Expr<Aug>>>,
120 distinct: bool,
121 over: Option<WindowSpec<Aug>>,
122 ) -> Expr<Aug> {
123 if self.rewriting_table_factor && self.status.is_ok() {
124 self.status = Err(PlanError::Unstructured(
125 "aggregate functions are not supported in functions in FROM".to_string(),
126 ))
127 }
128 Expr::Function(Function {
129 name,
130 args: FunctionArgs::Args {
131 args: vec![expr],
132 order_by,
133 },
134 filter,
135 over,
136 distinct,
137 })
138 }
139
140 fn plan_avg(
141 &mut self,
142 expr: Expr<Aug>,
143 filter: Option<Box<Expr<Aug>>>,
144 distinct: bool,
145 over: Option<WindowSpec<Aug>>,
146 ) -> Expr<Aug> {
147 let sum = self
148 .plan_agg(
149 self.scx
150 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
151 expr.clone(),
152 vec![],
153 filter.clone(),
154 distinct,
155 over.clone(),
156 )
157 .call_unary(
158 self.scx
159 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
160 );
161 let count = self.plan_agg(
162 self.scx
163 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
164 expr,
165 vec![],
166 filter,
167 distinct,
168 over,
169 );
170 Self::plan_divide(sum, count)
171 }
172
173 fn plan_avg_internal_v1(
175 &mut self,
176 expr: Expr<Aug>,
177 filter: Option<Box<Expr<Aug>>>,
178 distinct: bool,
179 over: Option<WindowSpec<Aug>>,
180 ) -> Expr<Aug> {
181 let sum = self
182 .plan_agg(
183 self.scx
184 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
185 expr.clone(),
186 vec![],
187 filter.clone(),
188 distinct,
189 over.clone(),
190 )
191 .call_unary(
192 self.scx
193 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion_internal_v1"]),
194 );
195 let count = self.plan_agg(
196 self.scx
197 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
198 expr,
199 vec![],
200 filter,
201 distinct,
202 over,
203 );
204 Self::plan_divide(sum, count)
205 }
206
207 fn plan_variance(
208 &mut self,
209 expr: Expr<Aug>,
210 filter: Option<Box<Expr<Aug>>>,
211 distinct: bool,
212 sample: bool,
213 over: Option<WindowSpec<Aug>>,
214 ) -> Expr<Aug> {
215 let expr = expr.call_unary(
230 self.scx
231 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
232 );
233 let expr_squared = expr.clone().multiply(expr.clone());
234 let sum_squares = self.plan_agg(
235 self.scx
236 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
237 expr_squared,
238 vec![],
239 filter.clone(),
240 distinct,
241 over.clone(),
242 );
243 let sum = self.plan_agg(
244 self.scx
245 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
246 expr.clone(),
247 vec![],
248 filter.clone(),
249 distinct,
250 over.clone(),
251 );
252 let sum_squared = sum.clone().multiply(sum);
253 let count = self.plan_agg(
254 self.scx
255 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
256 expr,
257 vec![],
258 filter,
259 distinct,
260 over,
261 );
262 let result = Self::plan_divide(
263 sum_squares.minus(Self::plan_divide(sum_squared, count.clone())),
264 if sample {
265 count.minus(Expr::number("1"))
266 } else {
267 count
268 },
269 );
270 let result_is_null = Expr::IsExpr {
285 expr: Box::new(result.clone()),
286 construct: IsExprConstruct::Null,
287 negated: false,
288 };
289 Expr::Case {
290 operand: None,
291 conditions: vec![result_is_null],
292 results: vec![Expr::Value(Value::Null)],
293 else_result: Some(Box::new(Expr::HomogenizingFunction {
294 function: HomogenizingFunction::Greatest,
295 exprs: vec![result, Expr::number("0")],
296 })),
297 }
298 }
299
300 fn plan_stddev(
301 &mut self,
302 expr: Expr<Aug>,
303 filter: Option<Box<Expr<Aug>>>,
304 distinct: bool,
305 sample: bool,
306 over: Option<WindowSpec<Aug>>,
307 ) -> Expr<Aug> {
308 self.plan_variance(expr, filter, distinct, sample, over)
309 .call_unary(
310 self.scx
311 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sqrt"]),
312 )
313 }
314
315 fn plan_bool_and(
316 &mut self,
317 expr: Expr<Aug>,
318 filter: Option<Box<Expr<Aug>>>,
319 distinct: bool,
320 over: Option<WindowSpec<Aug>>,
321 ) -> Expr<Aug> {
322 let sum = self.plan_agg(
334 self.scx
335 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
336 expr.negate().cast(self.int32_data_type()),
337 vec![],
338 filter,
339 distinct,
340 over,
341 );
342 sum.equals(Expr::Value(Value::Number(0.to_string())))
343 }
344
345 fn plan_bool_or(
346 &mut self,
347 expr: Expr<Aug>,
348 filter: Option<Box<Expr<Aug>>>,
349 distinct: bool,
350 over: Option<WindowSpec<Aug>>,
351 ) -> Expr<Aug> {
352 let sum = self.plan_agg(
364 self.scx
365 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
366 expr.or(Expr::Value(Value::Boolean(false)))
367 .cast(self.int32_data_type()),
368 vec![],
369 filter,
370 distinct,
371 over,
372 );
373 sum.gt(Expr::Value(Value::Number(0.to_string())))
374 }
375
376 fn rewrite_function(&mut self, func: &Function<Aug>) -> Option<(Ident, Expr<Aug>)> {
377 if let Function {
378 name,
379 args: FunctionArgs::Args { args, order_by: _ },
380 filter,
381 distinct,
382 over,
383 } = func
384 {
385 let pg_catalog_id = self
386 .scx
387 .catalog
388 .resolve_schema(None, PG_CATALOG_SCHEMA)
389 .expect("pg_catalog schema exists")
390 .id();
391 let mz_catalog_id = self
392 .scx
393 .catalog
394 .resolve_schema(None, MZ_CATALOG_SCHEMA)
395 .expect("mz_catalog schema exists")
396 .id();
397 let name = match name {
398 ResolvedItemName::Item {
399 qualifiers,
400 full_name,
401 ..
402 } => {
403 if ![*pg_catalog_id, *mz_catalog_id].contains(&qualifiers.schema_spec) {
404 return None;
405 }
406 full_name.item.clone()
407 }
408 _ => unreachable!(),
409 };
410
411 let filter = filter.clone();
412 let distinct = *distinct;
413 let over = over.clone();
414 let expr = if args.len() == 1 {
415 let arg = args[0].clone();
416 match name.as_str() {
417 "avg_internal_v1" => self.plan_avg_internal_v1(arg, filter, distinct, over),
418 "avg" => self.plan_avg(arg, filter, distinct, over),
419 "variance" | "var_samp" => {
420 self.plan_variance(arg, filter, distinct, true, over)
421 }
422 "var_pop" => self.plan_variance(arg, filter, distinct, false, over),
423 "stddev" | "stddev_samp" => self.plan_stddev(arg, filter, distinct, true, over),
424 "stddev_pop" => self.plan_stddev(arg, filter, distinct, false, over),
425 "bool_and" => self.plan_bool_and(arg, filter, distinct, over),
426 "bool_or" => self.plan_bool_or(arg, filter, distinct, over),
427 _ => return None,
428 }
429 } else if args.len() == 2 {
430 let (lhs, rhs) = (args[0].clone(), args[1].clone());
431 match name.as_str() {
432 "mod" => lhs.modulo(rhs),
433 "pow" => Expr::call(
434 self.scx
435 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "power"]),
436 vec![lhs, rhs],
437 ),
438 _ => return None,
439 }
440 } else {
441 return None;
442 };
443 Some((Ident::new_unchecked(name), expr))
444 } else {
445 None
446 }
447 }
448
449 fn rewrite_expr(&mut self, expr: &Expr<Aug>) -> Option<(Ident, Expr<Aug>)> {
450 match expr {
451 Expr::Function(function) => self.rewrite_function(function),
452 Expr::Identifier(ident) if ident.len() == 1 => {
456 let ident = normalize::ident(ident[0].clone());
457 let fn_ident = match ident.as_str() {
458 "current_role" => Some("current_user"),
459 "current_schema" | "current_timestamp" | "current_user" | "session_user"
460 | "current_catalog" => Some(ident.as_str()),
461 _ => None,
462 };
463 match fn_ident {
464 None => None,
465 Some(fn_ident) => {
466 let expr = Expr::call_nullary(
467 self.scx
468 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, fn_ident]),
469 );
470 Some((Ident::new_unchecked(ident), expr))
471 }
472 }
473 }
474 _ => None,
475 }
476 }
477}
478
479impl<'ast> VisitMut<'ast, Aug> for FuncRewriter<'_> {
480 fn visit_select_item_mut(&mut self, item: &'ast mut SelectItem<Aug>) {
481 if let SelectItem::Expr { expr, alias: None } = item {
482 visit_mut::visit_expr_mut(self, expr);
483 if let Some((alias, expr)) = self.rewrite_expr(expr) {
484 *item = SelectItem::Expr {
485 expr,
486 alias: Some(alias),
487 };
488 }
489 } else {
490 visit_mut::visit_select_item_mut(self, item);
491 }
492 }
493
494 fn visit_table_with_joins_mut(&mut self, item: &'ast mut TableWithJoins<Aug>) {
495 visit_mut::visit_table_with_joins_mut(self, item);
496 match &mut item.relation {
497 TableFactor::Function {
498 function,
499 alias,
500 with_ordinality,
501 } => {
502 self.rewriting_table_factor = true;
503 if let Some((ident, expr)) = self.rewrite_function(function) {
506 let mut select = Select::default().project(SelectItem::Expr {
507 expr,
508 alias: Some(match &alias {
509 Some(TableAlias { name, columns, .. }) => {
510 columns.get(0).unwrap_or(name).clone()
511 }
512 None => ident,
513 }),
514 });
515
516 if *with_ordinality {
517 select = select.project(SelectItem::Expr {
518 expr: Expr::Value(Value::Number("1".into())),
519 alias: Some(ident!(ORDINALITY_COL_NAME)),
520 });
521 }
522
523 item.relation = TableFactor::Derived {
524 lateral: false,
525 subquery: Box::new(Query {
526 ctes: mz_sql_parser::ast::CteBlock::Simple(vec![]),
527 body: mz_sql_parser::ast::SetExpr::Select(Box::new(select)),
528 order_by: vec![],
529 limit: None,
530 offset: None,
531 }),
532 alias: alias.clone(),
533 }
534 }
535 self.rewriting_table_factor = false;
536 }
537 _ => {}
538 }
539 }
540
541 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
542 visit_mut::visit_expr_mut(self, expr);
543 if let Some((_name, new_expr)) = self.rewrite_expr(expr) {
544 *expr = new_expr;
545 }
546 }
547}
548
549struct Desugarer<'a> {
554 scx: &'a StatementContext<'a>,
555 status: Result<(), PlanError>,
556 id_gen: IdGen,
557 recursion_guard: RecursionGuard,
558}
559
560impl<'a> CheckedRecursion for Desugarer<'a> {
561 fn recursion_guard(&self) -> &RecursionGuard {
562 &self.recursion_guard
563 }
564}
565
566impl<'a, 'ast> VisitMut<'ast, Aug> for Desugarer<'a> {
567 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
568 self.visit_internal(Self::visit_expr_mut_internal, expr);
569 }
570}
571
572impl<'a> Desugarer<'a> {
573 fn visit_internal<F, X>(&mut self, f: F, x: X)
574 where
575 F: Fn(&mut Self, X) -> Result<(), PlanError>,
576 {
577 if self.status.is_ok() {
578 let status = self.checked_recur_mut(|d| f(d, x));
581 if self.status.is_ok() {
582 self.status = status;
583 }
584 }
585 }
586
587 fn new(scx: &'a StatementContext) -> Desugarer<'a> {
588 Desugarer {
589 scx,
590 status: Ok(()),
591 id_gen: Default::default(),
592 recursion_guard: RecursionGuard::with_limit(1024), }
594 }
595
596 fn visit_expr_mut_internal(&mut self, expr: &mut Expr<Aug>) -> Result<(), PlanError> {
597 while let Expr::Nested(e) = expr {
599 *expr = e.take();
600 }
601
602 if let Expr::Between {
605 expr: e,
606 low,
607 high,
608 negated,
609 } = expr
610 {
611 if *negated {
612 *expr = Expr::lt(*e.clone(), low.take()).or(e.take().gt(high.take()));
613 } else {
614 *expr = e.clone().gt_eq(low.take()).and(e.take().lt_eq(high.take()));
615 }
616 }
617
618 if let Expr::InList {
628 expr: e,
629 list,
630 negated,
631 } = expr
632 {
633 if let Expr::Row { .. } = &**e {
634 if *negated {
635 *expr = list
636 .drain(..)
637 .map(|r| e.clone().not_equals(r))
638 .reduce(|e1, e2| e1.and(e2))
639 .expect("list known to contain at least one element");
640 } else {
641 *expr = list
642 .drain(..)
643 .map(|r| e.clone().equals(r))
644 .reduce(|e1, e2| e1.or(e2))
645 .expect("list known to contain at least one element");
646 }
647 }
648 }
649
650 if let Expr::InSubquery {
653 expr: e,
654 subquery,
655 negated,
656 } = expr
657 {
658 if *negated {
659 *expr = Expr::AllSubquery {
660 left: Box::new(e.take()),
661 op: Op::bare("<>"),
662 right: Box::new(subquery.take()),
663 };
664 } else {
665 *expr = Expr::AnySubquery {
666 left: Box::new(e.take()),
667 op: Op::bare("="),
668 right: Box::new(subquery.take()),
669 };
670 }
671 }
672
673 if let Expr::AnyExpr { left, op, right } | Expr::AllExpr { left, op, right } = expr {
679 let binding = ident!("elem");
680
681 let subquery = Query::select(
682 Select::default()
683 .from(TableWithJoins {
684 relation: TableFactor::Function {
685 function: Function {
686 name: self
687 .scx
688 .dangerous_resolve_name(vec![MZ_CATALOG_SCHEMA, "unnest"]),
689 args: FunctionArgs::args(vec![right.take()]),
690 filter: None,
691 over: None,
692 distinct: false,
693 },
694 alias: Some(TableAlias {
695 name: ident!("_"),
696 columns: vec![binding.clone()],
697 strict: true,
698 }),
699 with_ordinality: false,
700 },
701 joins: vec![],
702 })
703 .project(SelectItem::Expr {
704 expr: Expr::Identifier(vec![binding]),
705 alias: None,
706 }),
707 );
708
709 let left = Box::new(left.take());
710
711 let op = op.clone();
712
713 *expr = match expr {
714 Expr::AnyExpr { .. } => Expr::AnySubquery {
715 left,
716 op,
717 right: Box::new(subquery),
718 },
719 Expr::AllExpr { .. } => Expr::AllSubquery {
720 left,
721 op,
722 right: Box::new(subquery),
723 },
724 _ => unreachable!(),
725 };
726 }
727
728 if let Expr::AnySubquery { left, op, right } | Expr::AllSubquery { left, op, right } = expr
734 {
735 let left = match &mut **left {
736 Expr::Row { .. } => left.take(),
737 _ => Expr::Row {
738 exprs: vec![left.take()],
739 },
740 };
741
742 let arity = match &left {
743 Expr::Row { exprs } => exprs.len(),
744 _ => unreachable!(),
745 };
746
747 let bindings: Vec<_> = (0..arity)
748 .map(|col| {
751 let unique_id = self.id_gen.allocate_id();
752 Ident::new_unchecked(format!("right_col{col}_{unique_id}"))
753 })
754 .collect();
755
756 let subquery_unique_id = self.id_gen.allocate_id();
757 let subquery_name = Ident::new_unchecked(format!("subquery{subquery_unique_id}"));
759 let select = Select::default()
760 .from(TableWithJoins::subquery(
761 right.take(),
762 TableAlias {
763 name: subquery_name,
764 columns: bindings.clone(),
765 strict: true,
766 },
767 ))
768 .project(SelectItem::Expr {
769 expr: left
770 .binop(
771 op.clone(),
772 Expr::Row {
773 exprs: bindings
774 .into_iter()
775 .map(|b| Expr::Identifier(vec![b]))
776 .collect(),
777 },
778 )
779 .call_unary(self.scx.dangerous_resolve_name(match expr {
780 Expr::AnySubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_any"],
781 Expr::AllSubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_all"],
782 _ => unreachable!(),
783 })),
784 alias: None,
785 });
786
787 *expr = Expr::Subquery(Box::new(Query::select(select)));
788 }
789
790 if let Expr::Op {
806 op,
807 expr1: left,
808 expr2: Some(right),
809 } = expr
810 {
811 if let (Expr::Row { exprs: left }, Expr::Row { exprs: right }) =
812 (&mut **left, &mut **right)
813 {
814 if matches!(normalize::op(op)?, "=" | "<>" | "<" | "<=" | ">" | ">=") {
815 if left.len() != right.len() {
816 sql_bail!("unequal number of entries in row expressions");
817 }
818 if left.is_empty() {
819 assert!(right.is_empty());
820 sql_bail!("cannot compare rows of zero length");
821 }
822 }
823 match normalize::op(op)? {
824 "=" | "<>" => {
825 let mut pairs = left.iter_mut().zip_eq(right);
826 let mut new = pairs
827 .next()
828 .map(|(l, r)| l.take().equals(r.take()))
829 .expect("cannot compare rows of zero length");
830 for (l, r) in pairs {
831 new = l.take().equals(r.take()).and(new);
832 }
833 if normalize::op(op)? == "<>" {
834 new = new.negate();
835 }
836 *expr = new;
837 }
838 "<" | "<=" | ">" | ">=" => {
839 let strict_op = match normalize::op(op)? {
840 "<" | "<=" => "<",
841 ">" | ">=" => ">",
842 _ => unreachable!(),
843 };
844 let (l, r) = (left.last_mut().unwrap(), right.last_mut().unwrap());
845 let mut new = l.take().binop(op.clone(), r.take());
846 for (l, r) in left
847 .iter_mut()
848 .rev()
849 .zip_eq(right.into_iter().rev())
850 .skip(1)
851 {
852 new = l
853 .clone()
854 .binop(Op::bare(strict_op), r.clone())
855 .or(l.take().equals(r.take()).and(new));
856 }
857 *expr = new;
858 }
859 _ if left.len() == 1 && right.len() == 1 => {
860 let left = left.remove(0);
861 let right = right.remove(0);
862 *expr = left.binop(op.clone(), right);
863 }
864 _ => (),
865 }
866 }
867 }
868
869 visit_mut::visit_expr_mut(self, expr);
870 Ok(())
871 }
872}