1use mz_ore::id_gen::IdGen;
17use mz_ore::stack::{CheckedRecursion, RecursionGuard};
18use mz_repr::namespaces::{MZ_CATALOG_SCHEMA, MZ_UNSAFE_SCHEMA, PG_CATALOG_SCHEMA};
19use mz_sql_parser::ast::visit_mut::{self, VisitMut, VisitMutNode};
20use mz_sql_parser::ast::{
21 Expr, Function, FunctionArgs, HomogenizingFunction, Ident, IsExprConstruct, Op, OrderByExpr,
22 Query, Select, SelectItem, TableAlias, TableFactor, TableWithJoins, Value, WindowSpec,
23};
24use mz_sql_parser::ident;
25
26use crate::names::{Aug, PartialItemName, ResolvedDataType, ResolvedItemName};
27use crate::normalize;
28use crate::plan::{PlanError, StatementContext};
29
30pub(crate) fn transform<N>(scx: &StatementContext, node: &mut N) -> Result<(), PlanError>
31where
32 N: for<'a> VisitMutNode<'a, Aug>,
33{
34 let mut func_rewriter = FuncRewriter::new(scx);
35 node.visit_mut(&mut func_rewriter);
36 func_rewriter.status?;
37
38 let mut desugarer = Desugarer::new(scx);
39 node.visit_mut(&mut desugarer);
40 desugarer.status
41}
42
43struct FuncRewriter<'a> {
64 scx: &'a StatementContext<'a>,
65 status: Result<(), PlanError>,
66 rewriting_table_factor: bool,
67}
68
69impl<'a> FuncRewriter<'a> {
70 fn new(scx: &'a StatementContext<'a>) -> FuncRewriter<'a> {
71 FuncRewriter {
72 scx,
73 status: Ok(()),
74 rewriting_table_factor: false,
75 }
76 }
77
78 fn resolve_known_valid_data_type(&self, name: &PartialItemName) -> ResolvedDataType {
79 let item = self
80 .scx
81 .catalog
82 .resolve_type(name)
83 .expect("data type known to be valid");
84 let full_name = self.scx.catalog.resolve_full_name(item.name());
85 ResolvedDataType::Named {
86 id: item.id(),
87 qualifiers: item.name().qualifiers.clone(),
88 full_name,
89 modifiers: vec![],
90 print_id: true,
91 }
92 }
93
94 fn int32_data_type(&self) -> ResolvedDataType {
95 self.resolve_known_valid_data_type(&PartialItemName {
96 database: None,
97 schema: Some(PG_CATALOG_SCHEMA.into()),
98 item: "int4".into(),
99 })
100 }
101
102 fn plan_divide(lhs: Expr<Aug>, rhs: Expr<Aug>) -> Expr<Aug> {
105 lhs.divide(Expr::Case {
106 operand: None,
107 conditions: vec![rhs.clone().equals(Expr::number("0"))],
108 results: vec![Expr::null()],
109 else_result: Some(Box::new(rhs)),
110 })
111 }
112
113 fn plan_agg(
114 &mut self,
115 name: ResolvedItemName,
116 expr: Expr<Aug>,
117 order_by: Vec<OrderByExpr<Aug>>,
118 filter: Option<Box<Expr<Aug>>>,
119 distinct: bool,
120 over: Option<WindowSpec<Aug>>,
121 ) -> Expr<Aug> {
122 if self.rewriting_table_factor && self.status.is_ok() {
123 self.status = Err(PlanError::Unstructured(
124 "aggregate functions are not supported in functions in FROM".to_string(),
125 ))
126 }
127 Expr::Function(Function {
128 name,
129 args: FunctionArgs::Args {
130 args: vec![expr],
131 order_by,
132 },
133 filter,
134 over,
135 distinct,
136 })
137 }
138
139 fn plan_avg(
140 &mut self,
141 expr: Expr<Aug>,
142 filter: Option<Box<Expr<Aug>>>,
143 distinct: bool,
144 over: Option<WindowSpec<Aug>>,
145 ) -> Expr<Aug> {
146 let sum = self
147 .plan_agg(
148 self.scx
149 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
150 expr.clone(),
151 vec![],
152 filter.clone(),
153 distinct,
154 over.clone(),
155 )
156 .call_unary(
157 self.scx
158 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
159 );
160 let count = self.plan_agg(
161 self.scx
162 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
163 expr,
164 vec![],
165 filter,
166 distinct,
167 over,
168 );
169 Self::plan_divide(sum, count)
170 }
171
172 fn plan_avg_internal_v1(
174 &mut self,
175 expr: Expr<Aug>,
176 filter: Option<Box<Expr<Aug>>>,
177 distinct: bool,
178 over: Option<WindowSpec<Aug>>,
179 ) -> Expr<Aug> {
180 let sum = self
181 .plan_agg(
182 self.scx
183 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
184 expr.clone(),
185 vec![],
186 filter.clone(),
187 distinct,
188 over.clone(),
189 )
190 .call_unary(
191 self.scx
192 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion_internal_v1"]),
193 );
194 let count = self.plan_agg(
195 self.scx
196 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
197 expr,
198 vec![],
199 filter,
200 distinct,
201 over,
202 );
203 Self::plan_divide(sum, count)
204 }
205
206 fn plan_variance(
207 &mut self,
208 expr: Expr<Aug>,
209 filter: Option<Box<Expr<Aug>>>,
210 distinct: bool,
211 sample: bool,
212 over: Option<WindowSpec<Aug>>,
213 ) -> Expr<Aug> {
214 let expr = expr.call_unary(
229 self.scx
230 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
231 );
232 let expr_squared = expr.clone().multiply(expr.clone());
233 let sum_squares = self.plan_agg(
234 self.scx
235 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
236 expr_squared,
237 vec![],
238 filter.clone(),
239 distinct,
240 over.clone(),
241 );
242 let sum = self.plan_agg(
243 self.scx
244 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
245 expr.clone(),
246 vec![],
247 filter.clone(),
248 distinct,
249 over.clone(),
250 );
251 let sum_squared = sum.clone().multiply(sum);
252 let count = self.plan_agg(
253 self.scx
254 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
255 expr,
256 vec![],
257 filter,
258 distinct,
259 over,
260 );
261 let result = Self::plan_divide(
262 sum_squares.minus(Self::plan_divide(sum_squared, count.clone())),
263 if sample {
264 count.minus(Expr::number("1"))
265 } else {
266 count
267 },
268 );
269 let result_is_null = Expr::IsExpr {
284 expr: Box::new(result.clone()),
285 construct: IsExprConstruct::Null,
286 negated: false,
287 };
288 Expr::Case {
289 operand: None,
290 conditions: vec![result_is_null],
291 results: vec![Expr::Value(Value::Null)],
292 else_result: Some(Box::new(Expr::HomogenizingFunction {
293 function: HomogenizingFunction::Greatest,
294 exprs: vec![result, Expr::number("0")],
295 })),
296 }
297 }
298
299 fn plan_stddev(
300 &mut self,
301 expr: Expr<Aug>,
302 filter: Option<Box<Expr<Aug>>>,
303 distinct: bool,
304 sample: bool,
305 over: Option<WindowSpec<Aug>>,
306 ) -> Expr<Aug> {
307 self.plan_variance(expr, filter, distinct, sample, over)
308 .call_unary(
309 self.scx
310 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sqrt"]),
311 )
312 }
313
314 fn plan_bool_and(
315 &mut self,
316 expr: Expr<Aug>,
317 filter: Option<Box<Expr<Aug>>>,
318 distinct: bool,
319 over: Option<WindowSpec<Aug>>,
320 ) -> Expr<Aug> {
321 let sum = self.plan_agg(
333 self.scx
334 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
335 expr.negate().cast(self.int32_data_type()),
336 vec![],
337 filter,
338 distinct,
339 over,
340 );
341 sum.equals(Expr::Value(Value::Number(0.to_string())))
342 }
343
344 fn plan_bool_or(
345 &mut self,
346 expr: Expr<Aug>,
347 filter: Option<Box<Expr<Aug>>>,
348 distinct: bool,
349 over: Option<WindowSpec<Aug>>,
350 ) -> Expr<Aug> {
351 let sum = self.plan_agg(
363 self.scx
364 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
365 expr.or(Expr::Value(Value::Boolean(false)))
366 .cast(self.int32_data_type()),
367 vec![],
368 filter,
369 distinct,
370 over,
371 );
372 sum.gt(Expr::Value(Value::Number(0.to_string())))
373 }
374
375 fn rewrite_function(&mut self, func: &Function<Aug>) -> Option<(Ident, Expr<Aug>)> {
376 if let Function {
377 name,
378 args: FunctionArgs::Args { args, order_by: _ },
379 filter,
380 distinct,
381 over,
382 } = func
383 {
384 let pg_catalog_id = self
385 .scx
386 .catalog
387 .resolve_schema(None, PG_CATALOG_SCHEMA)
388 .expect("pg_catalog schema exists")
389 .id();
390 let mz_catalog_id = self
391 .scx
392 .catalog
393 .resolve_schema(None, MZ_CATALOG_SCHEMA)
394 .expect("mz_catalog schema exists")
395 .id();
396 let name = match name {
397 ResolvedItemName::Item {
398 qualifiers,
399 full_name,
400 ..
401 } => {
402 if ![*pg_catalog_id, *mz_catalog_id].contains(&qualifiers.schema_spec) {
403 return None;
404 }
405 full_name.item.clone()
406 }
407 _ => unreachable!(),
408 };
409
410 let filter = filter.clone();
411 let distinct = *distinct;
412 let over = over.clone();
413 let expr = if args.len() == 1 {
414 let arg = args[0].clone();
415 match name.as_str() {
416 "avg_internal_v1" => self.plan_avg_internal_v1(arg, filter, distinct, over),
417 "avg" => self.plan_avg(arg, filter, distinct, over),
418 "variance" | "var_samp" => {
419 self.plan_variance(arg, filter, distinct, true, over)
420 }
421 "var_pop" => self.plan_variance(arg, filter, distinct, false, over),
422 "stddev" | "stddev_samp" => self.plan_stddev(arg, filter, distinct, true, over),
423 "stddev_pop" => self.plan_stddev(arg, filter, distinct, false, over),
424 "bool_and" => self.plan_bool_and(arg, filter, distinct, over),
425 "bool_or" => self.plan_bool_or(arg, filter, distinct, over),
426 _ => return None,
427 }
428 } else if args.len() == 2 {
429 let (lhs, rhs) = (args[0].clone(), args[1].clone());
430 match name.as_str() {
431 "mod" => lhs.modulo(rhs),
432 "pow" => Expr::call(
433 self.scx
434 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "power"]),
435 vec![lhs, rhs],
436 ),
437 _ => return None,
438 }
439 } else {
440 return None;
441 };
442 Some((Ident::new_unchecked(name), expr))
443 } else {
444 None
445 }
446 }
447
448 fn rewrite_expr(&mut self, expr: &Expr<Aug>) -> Option<(Ident, Expr<Aug>)> {
449 match expr {
450 Expr::Function(function) => self.rewrite_function(function),
451 Expr::Identifier(ident) if ident.len() == 1 => {
455 let ident = normalize::ident(ident[0].clone());
456 let fn_ident = match ident.as_str() {
457 "current_role" => Some("current_user"),
458 "current_schema" | "current_timestamp" | "current_user" | "session_user" => {
459 Some(ident.as_str())
460 }
461 _ => None,
462 };
463 match fn_ident {
464 None => None,
465 Some(fn_ident) => {
466 let expr = Expr::call_nullary(
467 self.scx
468 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, fn_ident]),
469 );
470 Some((Ident::new_unchecked(ident), expr))
471 }
472 }
473 }
474 _ => None,
475 }
476 }
477}
478
479impl<'ast> VisitMut<'ast, Aug> for FuncRewriter<'_> {
480 fn visit_select_item_mut(&mut self, item: &'ast mut SelectItem<Aug>) {
481 if let SelectItem::Expr { expr, alias: None } = item {
482 visit_mut::visit_expr_mut(self, expr);
483 if let Some((alias, expr)) = self.rewrite_expr(expr) {
484 *item = SelectItem::Expr {
485 expr,
486 alias: Some(alias),
487 };
488 }
489 } else {
490 visit_mut::visit_select_item_mut(self, item);
491 }
492 }
493
494 fn visit_table_with_joins_mut(&mut self, item: &'ast mut TableWithJoins<Aug>) {
495 visit_mut::visit_table_with_joins_mut(self, item);
496 match &mut item.relation {
497 TableFactor::Function {
498 function,
499 alias,
500 with_ordinality,
501 } => {
502 self.rewriting_table_factor = true;
503 if let Some((ident, expr)) = self.rewrite_function(function) {
506 let mut select = Select::default().project(SelectItem::Expr {
507 expr,
508 alias: Some(match &alias {
509 Some(TableAlias { name, columns, .. }) => {
510 columns.get(0).unwrap_or(name).clone()
511 }
512 None => ident,
513 }),
514 });
515
516 if *with_ordinality {
517 select = select.project(SelectItem::Expr {
518 expr: Expr::Value(Value::Number("1".into())),
519 alias: Some(ident!("ordinality")),
520 });
521 }
522
523 item.relation = TableFactor::Derived {
524 lateral: false,
525 subquery: Box::new(Query {
526 ctes: mz_sql_parser::ast::CteBlock::Simple(vec![]),
527 body: mz_sql_parser::ast::SetExpr::Select(Box::new(select)),
528 order_by: vec![],
529 limit: None,
530 offset: None,
531 }),
532 alias: alias.clone(),
533 }
534 }
535 self.rewriting_table_factor = false;
536 }
537 _ => {}
538 }
539 }
540
541 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
542 visit_mut::visit_expr_mut(self, expr);
543 if let Some((_name, new_expr)) = self.rewrite_expr(expr) {
544 *expr = new_expr;
545 }
546 }
547}
548
549struct Desugarer<'a> {
554 scx: &'a StatementContext<'a>,
555 status: Result<(), PlanError>,
556 id_gen: IdGen,
557 recursion_guard: RecursionGuard,
558}
559
560impl<'a> CheckedRecursion for Desugarer<'a> {
561 fn recursion_guard(&self) -> &RecursionGuard {
562 &self.recursion_guard
563 }
564}
565
566impl<'a, 'ast> VisitMut<'ast, Aug> for Desugarer<'a> {
567 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
568 self.visit_internal(Self::visit_expr_mut_internal, expr);
569 }
570}
571
572impl<'a> Desugarer<'a> {
573 fn visit_internal<F, X>(&mut self, f: F, x: X)
574 where
575 F: Fn(&mut Self, X) -> Result<(), PlanError>,
576 {
577 if self.status.is_ok() {
578 let status = self.checked_recur_mut(|d| f(d, x));
581 if self.status.is_ok() {
582 self.status = status;
583 }
584 }
585 }
586
587 fn new(scx: &'a StatementContext) -> Desugarer<'a> {
588 Desugarer {
589 scx,
590 status: Ok(()),
591 id_gen: Default::default(),
592 recursion_guard: RecursionGuard::with_limit(1024), }
594 }
595
596 fn visit_expr_mut_internal(&mut self, expr: &mut Expr<Aug>) -> Result<(), PlanError> {
597 while let Expr::Nested(e) = expr {
599 *expr = e.take();
600 }
601
602 if let Expr::Between {
605 expr: e,
606 low,
607 high,
608 negated,
609 } = expr
610 {
611 if *negated {
612 *expr = Expr::lt(*e.clone(), low.take()).or(e.take().gt(high.take()));
613 } else {
614 *expr = e.clone().gt_eq(low.take()).and(e.take().lt_eq(high.take()));
615 }
616 }
617
618 if let Expr::InList {
628 expr: e,
629 list,
630 negated,
631 } = expr
632 {
633 if let Expr::Row { .. } = &**e {
634 if *negated {
635 *expr = list
636 .drain(..)
637 .map(|r| e.clone().not_equals(r))
638 .reduce(|e1, e2| e1.and(e2))
639 .expect("list known to contain at least one element");
640 } else {
641 *expr = list
642 .drain(..)
643 .map(|r| e.clone().equals(r))
644 .reduce(|e1, e2| e1.or(e2))
645 .expect("list known to contain at least one element");
646 }
647 }
648 }
649
650 if let Expr::InSubquery {
653 expr: e,
654 subquery,
655 negated,
656 } = expr
657 {
658 if *negated {
659 *expr = Expr::AllSubquery {
660 left: Box::new(e.take()),
661 op: Op::bare("<>"),
662 right: Box::new(subquery.take()),
663 };
664 } else {
665 *expr = Expr::AnySubquery {
666 left: Box::new(e.take()),
667 op: Op::bare("="),
668 right: Box::new(subquery.take()),
669 };
670 }
671 }
672
673 if let Expr::AnyExpr { left, op, right } | Expr::AllExpr { left, op, right } = expr {
679 let binding = ident!("elem");
680
681 let subquery = Query::select(
682 Select::default()
683 .from(TableWithJoins {
684 relation: TableFactor::Function {
685 function: Function {
686 name: self
687 .scx
688 .dangerous_resolve_name(vec![MZ_CATALOG_SCHEMA, "unnest"]),
689 args: FunctionArgs::args(vec![right.take()]),
690 filter: None,
691 over: None,
692 distinct: false,
693 },
694 alias: Some(TableAlias {
695 name: ident!("_"),
696 columns: vec![binding.clone()],
697 strict: true,
698 }),
699 with_ordinality: false,
700 },
701 joins: vec![],
702 })
703 .project(SelectItem::Expr {
704 expr: Expr::Identifier(vec![binding]),
705 alias: None,
706 }),
707 );
708
709 let left = Box::new(left.take());
710
711 let op = op.clone();
712
713 *expr = match expr {
714 Expr::AnyExpr { .. } => Expr::AnySubquery {
715 left,
716 op,
717 right: Box::new(subquery),
718 },
719 Expr::AllExpr { .. } => Expr::AllSubquery {
720 left,
721 op,
722 right: Box::new(subquery),
723 },
724 _ => unreachable!(),
725 };
726 }
727
728 if let Expr::AnySubquery { left, op, right } | Expr::AllSubquery { left, op, right } = expr
734 {
735 let left = match &mut **left {
736 Expr::Row { .. } => left.take(),
737 _ => Expr::Row {
738 exprs: vec![left.take()],
739 },
740 };
741
742 let arity = match &left {
743 Expr::Row { exprs } => exprs.len(),
744 _ => unreachable!(),
745 };
746
747 let bindings: Vec<_> = (0..arity)
748 .map(|col| {
751 let unique_id = self.id_gen.allocate_id();
752 Ident::new_unchecked(format!("right_col{col}_{unique_id}"))
753 })
754 .collect();
755
756 let subquery_unique_id = self.id_gen.allocate_id();
757 let subquery_name = Ident::new_unchecked(format!("subquery{subquery_unique_id}"));
759 let select = Select::default()
760 .from(TableWithJoins::subquery(
761 right.take(),
762 TableAlias {
763 name: subquery_name,
764 columns: bindings.clone(),
765 strict: true,
766 },
767 ))
768 .project(SelectItem::Expr {
769 expr: left
770 .binop(
771 op.clone(),
772 Expr::Row {
773 exprs: bindings
774 .into_iter()
775 .map(|b| Expr::Identifier(vec![b]))
776 .collect(),
777 },
778 )
779 .call_unary(self.scx.dangerous_resolve_name(match expr {
780 Expr::AnySubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_any"],
781 Expr::AllSubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_all"],
782 _ => unreachable!(),
783 })),
784 alias: None,
785 });
786
787 *expr = Expr::Subquery(Box::new(Query::select(select)));
788 }
789
790 if let Expr::Op {
806 op,
807 expr1: left,
808 expr2: Some(right),
809 } = expr
810 {
811 if let (Expr::Row { exprs: left }, Expr::Row { exprs: right }) =
812 (&mut **left, &mut **right)
813 {
814 if matches!(normalize::op(op)?, "=" | "<>" | "<" | "<=" | ">" | ">=") {
815 if left.len() != right.len() {
816 sql_bail!("unequal number of entries in row expressions");
817 }
818 if left.is_empty() {
819 assert!(right.is_empty());
820 sql_bail!("cannot compare rows of zero length");
821 }
822 }
823 match normalize::op(op)? {
824 "=" | "<>" => {
825 let mut pairs = left.iter_mut().zip(right);
826 let mut new = pairs
827 .next()
828 .map(|(l, r)| l.take().equals(r.take()))
829 .expect("cannot compare rows of zero length");
830 for (l, r) in pairs {
831 new = l.take().equals(r.take()).and(new);
832 }
833 if normalize::op(op)? == "<>" {
834 new = new.negate();
835 }
836 *expr = new;
837 }
838 "<" | "<=" | ">" | ">=" => {
839 let strict_op = match normalize::op(op)? {
840 "<" | "<=" => "<",
841 ">" | ">=" => ">",
842 _ => unreachable!(),
843 };
844 let (l, r) = (left.last_mut().unwrap(), right.last_mut().unwrap());
845 let mut new = l.take().binop(op.clone(), r.take());
846 for (l, r) in left.iter_mut().zip(right).rev().skip(1) {
847 new = l
848 .clone()
849 .binop(Op::bare(strict_op), r.clone())
850 .or(l.take().equals(r.take()).and(new));
851 }
852 *expr = new;
853 }
854 _ if left.len() == 1 && right.len() == 1 => {
855 let left = left.remove(0);
856 let right = right.remove(0);
857 *expr = left.binop(op.clone(), right);
858 }
859 _ => (),
860 }
861 }
862 }
863
864 visit_mut::visit_expr_mut(self, expr);
865 Ok(())
866 }
867}