1use mz_ore::stack::{CheckedRecursion, RecursionGuard};
17use mz_repr::namespaces::{MZ_CATALOG_SCHEMA, MZ_UNSAFE_SCHEMA, PG_CATALOG_SCHEMA};
18use mz_sql_parser::ast::visit_mut::{self, VisitMut, VisitMutNode};
19use mz_sql_parser::ast::{
20 Expr, Function, FunctionArgs, HomogenizingFunction, Ident, IsExprConstruct, Op, OrderByExpr,
21 Query, Select, SelectItem, TableAlias, TableFactor, TableWithJoins, Value, WindowSpec,
22};
23use mz_sql_parser::ident;
24use uuid::Uuid;
25
26use crate::names::{Aug, PartialItemName, ResolvedDataType, ResolvedItemName};
27use crate::normalize;
28use crate::plan::{PlanError, StatementContext};
29
30pub(crate) fn transform<N>(scx: &StatementContext, node: &mut N) -> Result<(), PlanError>
31where
32 N: for<'a> VisitMutNode<'a, Aug>,
33{
34 let mut func_rewriter = FuncRewriter::new(scx);
35 node.visit_mut(&mut func_rewriter);
36 func_rewriter.status?;
37
38 let mut desugarer = Desugarer::new(scx);
39 node.visit_mut(&mut desugarer);
40 desugarer.status
41}
42
/// AST visitor that replaces certain function calls and bare identifiers
/// with equivalent expressions (see `rewrite_function` and `rewrite_expr`).
struct FuncRewriter<'a> {
    /// Statement context used for catalog lookups and name resolution.
    scx: &'a StatementContext<'a>,
    /// The first error recorded during the visit, or `Ok(())` if none.
    status: Result<(), PlanError>,
    /// Set while a function used as a table factor is being rewritten;
    /// `plan_agg` records an error if it runs while this is `true`.
    rewriting_table_factor: bool,
}
68
69impl<'a> FuncRewriter<'a> {
70 fn new(scx: &'a StatementContext<'a>) -> FuncRewriter<'a> {
71 FuncRewriter {
72 scx,
73 status: Ok(()),
74 rewriting_table_factor: false,
75 }
76 }
77
78 fn resolve_known_valid_data_type(&self, name: &PartialItemName) -> ResolvedDataType {
79 let item = self
80 .scx
81 .catalog
82 .resolve_type(name)
83 .expect("data type known to be valid");
84 let full_name = self.scx.catalog.resolve_full_name(item.name());
85 ResolvedDataType::Named {
86 id: item.id(),
87 qualifiers: item.name().qualifiers.clone(),
88 full_name,
89 modifiers: vec![],
90 print_id: true,
91 }
92 }
93
94 fn int32_data_type(&self) -> ResolvedDataType {
95 self.resolve_known_valid_data_type(&PartialItemName {
96 database: None,
97 schema: Some(PG_CATALOG_SCHEMA.into()),
98 item: "int4".into(),
99 })
100 }
101
102 fn plan_divide(lhs: Expr<Aug>, rhs: Expr<Aug>) -> Expr<Aug> {
105 lhs.divide(Expr::Case {
106 operand: None,
107 conditions: vec![rhs.clone().equals(Expr::number("0"))],
108 results: vec![Expr::null()],
109 else_result: Some(Box::new(rhs)),
110 })
111 }
112
113 fn plan_agg(
114 &mut self,
115 name: ResolvedItemName,
116 expr: Expr<Aug>,
117 order_by: Vec<OrderByExpr<Aug>>,
118 filter: Option<Box<Expr<Aug>>>,
119 distinct: bool,
120 over: Option<WindowSpec<Aug>>,
121 ) -> Expr<Aug> {
122 if self.rewriting_table_factor && self.status.is_ok() {
123 self.status = Err(PlanError::Unstructured(
124 "aggregate functions are not supported in functions in FROM".to_string(),
125 ))
126 }
127 Expr::Function(Function {
128 name,
129 args: FunctionArgs::Args {
130 args: vec![expr],
131 order_by,
132 },
133 filter,
134 over,
135 distinct,
136 })
137 }
138
139 fn plan_avg(
140 &mut self,
141 expr: Expr<Aug>,
142 filter: Option<Box<Expr<Aug>>>,
143 distinct: bool,
144 over: Option<WindowSpec<Aug>>,
145 ) -> Expr<Aug> {
146 let sum = self
147 .plan_agg(
148 self.scx
149 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
150 expr.clone(),
151 vec![],
152 filter.clone(),
153 distinct,
154 over.clone(),
155 )
156 .call_unary(
157 self.scx
158 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
159 );
160 let count = self.plan_agg(
161 self.scx
162 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
163 expr,
164 vec![],
165 filter,
166 distinct,
167 over,
168 );
169 Self::plan_divide(sum, count)
170 }
171
172 fn plan_avg_internal_v1(
174 &mut self,
175 expr: Expr<Aug>,
176 filter: Option<Box<Expr<Aug>>>,
177 distinct: bool,
178 over: Option<WindowSpec<Aug>>,
179 ) -> Expr<Aug> {
180 let sum = self
181 .plan_agg(
182 self.scx
183 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
184 expr.clone(),
185 vec![],
186 filter.clone(),
187 distinct,
188 over.clone(),
189 )
190 .call_unary(
191 self.scx
192 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion_internal_v1"]),
193 );
194 let count = self.plan_agg(
195 self.scx
196 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
197 expr,
198 vec![],
199 filter,
200 distinct,
201 over,
202 );
203 Self::plan_divide(sum, count)
204 }
205
206 fn plan_variance(
207 &mut self,
208 expr: Expr<Aug>,
209 filter: Option<Box<Expr<Aug>>>,
210 distinct: bool,
211 sample: bool,
212 over: Option<WindowSpec<Aug>>,
213 ) -> Expr<Aug> {
214 let expr = expr.call_unary(
229 self.scx
230 .dangerous_resolve_name(vec![MZ_UNSAFE_SCHEMA, "mz_avg_promotion"]),
231 );
232 let expr_squared = expr.clone().multiply(expr.clone());
233 let sum_squares = self.plan_agg(
234 self.scx
235 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
236 expr_squared,
237 vec![],
238 filter.clone(),
239 distinct,
240 over.clone(),
241 );
242 let sum = self.plan_agg(
243 self.scx
244 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
245 expr.clone(),
246 vec![],
247 filter.clone(),
248 distinct,
249 over.clone(),
250 );
251 let sum_squared = sum.clone().multiply(sum);
252 let count = self.plan_agg(
253 self.scx
254 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "count"]),
255 expr,
256 vec![],
257 filter,
258 distinct,
259 over,
260 );
261 let result = Self::plan_divide(
262 sum_squares.minus(Self::plan_divide(sum_squared, count.clone())),
263 if sample {
264 count.minus(Expr::number("1"))
265 } else {
266 count
267 },
268 );
269 let result_is_null = Expr::IsExpr {
284 expr: Box::new(result.clone()),
285 construct: IsExprConstruct::Null,
286 negated: false,
287 };
288 Expr::Case {
289 operand: None,
290 conditions: vec![result_is_null],
291 results: vec![Expr::Value(Value::Null)],
292 else_result: Some(Box::new(Expr::HomogenizingFunction {
293 function: HomogenizingFunction::Greatest,
294 exprs: vec![result, Expr::number("0")],
295 })),
296 }
297 }
298
299 fn plan_stddev(
300 &mut self,
301 expr: Expr<Aug>,
302 filter: Option<Box<Expr<Aug>>>,
303 distinct: bool,
304 sample: bool,
305 over: Option<WindowSpec<Aug>>,
306 ) -> Expr<Aug> {
307 self.plan_variance(expr, filter, distinct, sample, over)
308 .call_unary(
309 self.scx
310 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sqrt"]),
311 )
312 }
313
314 fn plan_bool_and(
315 &mut self,
316 expr: Expr<Aug>,
317 filter: Option<Box<Expr<Aug>>>,
318 distinct: bool,
319 over: Option<WindowSpec<Aug>>,
320 ) -> Expr<Aug> {
321 let sum = self.plan_agg(
333 self.scx
334 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
335 expr.negate().cast(self.int32_data_type()),
336 vec![],
337 filter,
338 distinct,
339 over,
340 );
341 sum.equals(Expr::Value(Value::Number(0.to_string())))
342 }
343
344 fn plan_bool_or(
345 &mut self,
346 expr: Expr<Aug>,
347 filter: Option<Box<Expr<Aug>>>,
348 distinct: bool,
349 over: Option<WindowSpec<Aug>>,
350 ) -> Expr<Aug> {
351 let sum = self.plan_agg(
363 self.scx
364 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "sum"]),
365 expr.or(Expr::Value(Value::Boolean(false)))
366 .cast(self.int32_data_type()),
367 vec![],
368 filter,
369 distinct,
370 over,
371 );
372 sum.gt(Expr::Value(Value::Number(0.to_string())))
373 }
374
375 fn rewrite_function(&mut self, func: &Function<Aug>) -> Option<(Ident, Expr<Aug>)> {
376 if let Function {
377 name,
378 args: FunctionArgs::Args { args, order_by: _ },
379 filter,
380 distinct,
381 over,
382 } = func
383 {
384 let pg_catalog_id = self
385 .scx
386 .catalog
387 .resolve_schema(None, PG_CATALOG_SCHEMA)
388 .expect("pg_catalog schema exists")
389 .id();
390 let mz_catalog_id = self
391 .scx
392 .catalog
393 .resolve_schema(None, MZ_CATALOG_SCHEMA)
394 .expect("mz_catalog schema exists")
395 .id();
396 let name = match name {
397 ResolvedItemName::Item {
398 qualifiers,
399 full_name,
400 ..
401 } => {
402 if ![*pg_catalog_id, *mz_catalog_id].contains(&qualifiers.schema_spec) {
403 return None;
404 }
405 full_name.item.clone()
406 }
407 _ => unreachable!(),
408 };
409
410 let filter = filter.clone();
411 let distinct = *distinct;
412 let over = over.clone();
413 let expr = if args.len() == 1 {
414 let arg = args[0].clone();
415 match name.as_str() {
416 "avg_internal_v1" => self.plan_avg_internal_v1(arg, filter, distinct, over),
417 "avg" => self.plan_avg(arg, filter, distinct, over),
418 "variance" | "var_samp" => {
419 self.plan_variance(arg, filter, distinct, true, over)
420 }
421 "var_pop" => self.plan_variance(arg, filter, distinct, false, over),
422 "stddev" | "stddev_samp" => self.plan_stddev(arg, filter, distinct, true, over),
423 "stddev_pop" => self.plan_stddev(arg, filter, distinct, false, over),
424 "bool_and" => self.plan_bool_and(arg, filter, distinct, over),
425 "bool_or" => self.plan_bool_or(arg, filter, distinct, over),
426 _ => return None,
427 }
428 } else if args.len() == 2 {
429 let (lhs, rhs) = (args[0].clone(), args[1].clone());
430 match name.as_str() {
431 "mod" => lhs.modulo(rhs),
432 "pow" => Expr::call(
433 self.scx
434 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, "power"]),
435 vec![lhs, rhs],
436 ),
437 _ => return None,
438 }
439 } else {
440 return None;
441 };
442 Some((Ident::new_unchecked(name), expr))
443 } else {
444 None
445 }
446 }
447
448 fn rewrite_expr(&mut self, expr: &Expr<Aug>) -> Option<(Ident, Expr<Aug>)> {
449 match expr {
450 Expr::Function(function) => self.rewrite_function(function),
451 Expr::Identifier(ident) if ident.len() == 1 => {
455 let ident = normalize::ident(ident[0].clone());
456 let fn_ident = match ident.as_str() {
457 "current_role" => Some("current_user"),
458 "current_schema" | "current_timestamp" | "current_user" | "session_user" => {
459 Some(ident.as_str())
460 }
461 _ => None,
462 };
463 match fn_ident {
464 None => None,
465 Some(fn_ident) => {
466 let expr = Expr::call_nullary(
467 self.scx
468 .dangerous_resolve_name(vec![PG_CATALOG_SCHEMA, fn_ident]),
469 );
470 Some((Ident::new_unchecked(ident), expr))
471 }
472 }
473 }
474 _ => None,
475 }
476 }
477}
478
479impl<'ast> VisitMut<'ast, Aug> for FuncRewriter<'_> {
480 fn visit_select_item_mut(&mut self, item: &'ast mut SelectItem<Aug>) {
481 if let SelectItem::Expr { expr, alias: None } = item {
482 visit_mut::visit_expr_mut(self, expr);
483 if let Some((alias, expr)) = self.rewrite_expr(expr) {
484 *item = SelectItem::Expr {
485 expr,
486 alias: Some(alias),
487 };
488 }
489 } else {
490 visit_mut::visit_select_item_mut(self, item);
491 }
492 }
493
494 fn visit_table_with_joins_mut(&mut self, item: &'ast mut TableWithJoins<Aug>) {
495 visit_mut::visit_table_with_joins_mut(self, item);
496 match &mut item.relation {
497 TableFactor::Function {
498 function,
499 alias,
500 with_ordinality,
501 } => {
502 self.rewriting_table_factor = true;
503 if let Some((ident, expr)) = self.rewrite_function(function) {
506 let mut select = Select::default().project(SelectItem::Expr {
507 expr,
508 alias: Some(match &alias {
509 Some(TableAlias { name, columns, .. }) => {
510 columns.get(0).unwrap_or(name).clone()
511 }
512 None => ident,
513 }),
514 });
515
516 if *with_ordinality {
517 select = select.project(SelectItem::Expr {
518 expr: Expr::Value(Value::Number("1".into())),
519 alias: Some(ident!("ordinality")),
520 });
521 }
522
523 item.relation = TableFactor::Derived {
524 lateral: false,
525 subquery: Box::new(Query {
526 ctes: mz_sql_parser::ast::CteBlock::Simple(vec![]),
527 body: mz_sql_parser::ast::SetExpr::Select(Box::new(select)),
528 order_by: vec![],
529 limit: None,
530 offset: None,
531 }),
532 alias: alias.clone(),
533 }
534 }
535 self.rewriting_table_factor = false;
536 }
537 _ => {}
538 }
539 }
540
541 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
542 visit_mut::visit_expr_mut(self, expr);
543 if let Some((_name, new_expr)) = self.rewrite_expr(expr) {
544 *expr = new_expr;
545 }
546 }
547}
548
/// AST visitor that rewrites expressions into simpler equivalent forms
/// (see `visit_expr_mut_internal` for the individual rewrites).
struct Desugarer<'a> {
    /// Statement context used to resolve the names of built-in functions.
    scx: &'a StatementContext<'a>,
    /// The first error recorded during the visit, or `Ok(())` if none.
    status: Result<(), PlanError>,
    /// Recursion guard exposed through the `CheckedRecursion` impl.
    recursion_guard: RecursionGuard,
}
558
559impl<'a> CheckedRecursion for Desugarer<'a> {
560 fn recursion_guard(&self) -> &RecursionGuard {
561 &self.recursion_guard
562 }
563}
564
565impl<'a, 'ast> VisitMut<'ast, Aug> for Desugarer<'a> {
566 fn visit_expr_mut(&mut self, expr: &'ast mut Expr<Aug>) {
567 self.visit_internal(Self::visit_expr_mut_internal, expr);
568 }
569}
570
571impl<'a> Desugarer<'a> {
572 fn visit_internal<F, X>(&mut self, f: F, x: X)
573 where
574 F: Fn(&mut Self, X) -> Result<(), PlanError>,
575 {
576 if self.status.is_ok() {
577 let status = self.checked_recur_mut(|d| f(d, x));
580 if self.status.is_ok() {
581 self.status = status;
582 }
583 }
584 }
585
586 fn new(scx: &'a StatementContext) -> Desugarer<'a> {
587 Desugarer {
588 scx,
589 status: Ok(()),
590 recursion_guard: RecursionGuard::with_limit(1024), }
592 }
593
/// Desugars `expr` in place, then recurses into the children of the
/// resulting expression.
///
/// The rewrites are applied in sequence, so the output of an earlier
/// rewrite can be picked up by a later one (for example, `IN (subquery)`
/// becomes `= ANY (subquery)`, which is then turned into a scalar
/// subquery).
///
/// # Errors
///
/// Propagates errors from `normalize::op`. The `sql_bail!` calls report
/// row comparisons whose sides differ in length or are empty.
/// NOTE(review): `sql_bail!` is defined outside this view; it presumably
/// returns early with a `PlanError` -- confirm.
///
/// # Panics
///
/// Panics if an `IN` list with a `ROW` left-hand side is empty.
fn visit_expr_mut_internal(&mut self, expr: &mut Expr<Aug>) -> Result<(), PlanError> {
    // `($expr)` => `$expr`, repeated until no parentheses remain.
    while let Expr::Nested(e) = expr {
        *expr = e.take();
    }

    // `$expr BETWEEN $low AND $high` => `$expr >= $low AND $expr <= $high`
    // `$expr NOT BETWEEN $low AND $high` => `$expr < $low OR $expr > $high`
    //
    // `$expr` is cloned, so it appears twice in the result.
    if let Expr::Between {
        expr: e,
        low,
        high,
        negated,
    } = expr
    {
        if *negated {
            *expr = Expr::lt(*e.clone(), low.take()).or(e.take().gt(high.take()));
        } else {
            *expr = e.clone().gt_eq(low.take()).and(e.take().lt_eq(high.take()));
        }
    }

    // Only when `$expr` is a `ROW` constructor:
    //
    // `$expr IN ($l0, ..., $ln)` => `$expr = $l0 OR ... OR $expr = $ln`
    // `$expr NOT IN ($l0, ..., $ln)` => `$expr <> $l0 AND ... AND $expr <> $ln`
    //
    // `IN` lists with any other left-hand side are left untouched here.
    if let Expr::InList {
        expr: e,
        list,
        negated,
    } = expr
    {
        if let Expr::Row { .. } = &**e {
            if *negated {
                *expr = list
                    .drain(..)
                    .map(|r| e.clone().not_equals(r))
                    .reduce(|e1, e2| e1.and(e2))
                    .expect("list known to contain at least one element");
            } else {
                *expr = list
                    .drain(..)
                    .map(|r| e.clone().equals(r))
                    .reduce(|e1, e2| e1.or(e2))
                    .expect("list known to contain at least one element");
            }
        }
    }

    // `$expr IN ($subquery)` => `$expr = ANY ($subquery)`
    // `$expr NOT IN ($subquery)` => `$expr <> ALL ($subquery)`
    if let Expr::InSubquery {
        expr: e,
        subquery,
        negated,
    } = expr
    {
        if *negated {
            *expr = Expr::AllSubquery {
                left: Box::new(e.take()),
                op: Op::bare("<>"),
                right: Box::new(subquery.take()),
            };
        } else {
            *expr = Expr::AnySubquery {
                left: Box::new(e.take()),
                op: Op::bare("="),
                right: Box::new(subquery.take()),
            };
        }
    }

    // `$left $op ANY ($right)`
    // =>
    // `$left $op ANY (SELECT elem FROM unnest($right) _ (elem))`
    //
    // and likewise for `ALL`: the right-hand expression is turned into a
    // subquery over `mz_catalog.unnest`.
    if let Expr::AnyExpr { left, op, right } | Expr::AllExpr { left, op, right } = expr {
        let binding = ident!("elem");

        let subquery = Query::select(
            Select::default()
                .from(TableWithJoins {
                    relation: TableFactor::Function {
                        function: Function {
                            name: self
                                .scx
                                .dangerous_resolve_name(vec![MZ_CATALOG_SCHEMA, "unnest"]),
                            args: FunctionArgs::args(vec![right.take()]),
                            filter: None,
                            over: None,
                            distinct: false,
                        },
                        alias: Some(TableAlias {
                            name: ident!("_"),
                            columns: vec![binding.clone()],
                            strict: true,
                        }),
                        with_ordinality: false,
                    },
                    joins: vec![],
                })
                .project(SelectItem::Expr {
                    expr: Expr::Identifier(vec![binding]),
                    alias: None,
                }),
        );

        let left = Box::new(left.take());

        let op = op.clone();

        *expr = match expr {
            Expr::AnyExpr { .. } => Expr::AnySubquery {
                left,
                op,
                right: Box::new(subquery),
            },
            Expr::AllExpr { .. } => Expr::AllSubquery {
                left,
                op,
                right: Box::new(subquery),
            },
            _ => unreachable!(),
        };
    }

    // `$left $op ALL ($subquery)`
    // =>
    // `(SELECT mz_unsafe.mz_all($left $op ROW($bindings)) FROM ($subquery) AS subquery ($bindings))`
    //
    // and likewise for `ANY` with `mz_unsafe.mz_any`.
    if let Expr::AnySubquery { left, op, right } | Expr::AllSubquery { left, op, right } = expr
    {
        // Normalize the left-hand side to a `ROW`, wrapping a scalar in a
        // one-element row.
        let left = match &mut **left {
            Expr::Row { .. } => left.take(),
            _ => Expr::Row {
                exprs: vec![left.take()],
            },
        };

        let arity = match &left {
            Expr::Row { exprs } => exprs.len(),
            _ => unreachable!(),
        };

        // One freshly generated column name per element of the left row.
        // NOTE(review): the UUID suffix is presumably there to avoid
        // clashing with names in the user's query -- confirm.
        let bindings: Vec<_> = (0..arity)
            .map(|_| Ident::new_unchecked(format!("right_{}", Uuid::new_v4())))
            .collect();

        let select = Select::default()
            .from(TableWithJoins::subquery(
                right.take(),
                TableAlias {
                    name: ident!("subquery"),
                    columns: bindings.clone(),
                    strict: true,
                },
            ))
            .project(SelectItem::Expr {
                expr: left
                    .binop(
                        op.clone(),
                        Expr::Row {
                            exprs: bindings
                                .into_iter()
                                .map(|b| Expr::Identifier(vec![b]))
                                .collect(),
                        },
                    )
                    .call_unary(self.scx.dangerous_resolve_name(match expr {
                        Expr::AnySubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_any"],
                        Expr::AllSubquery { .. } => vec![MZ_UNSAFE_SCHEMA, "mz_all"],
                        _ => unreachable!(),
                    })),
                alias: None,
            });

        *expr = Expr::Subquery(Box::new(Query::select(select)));
    }

    // Expand comparisons between two `ROW` constructors.
    //
    // `ROW($l1, ..., $ln) = ROW($r1, ..., $rn)`
    // =>
    // `$l1 = $r1 AND ... AND $ln = $rn` (negated as a whole for `<>`)
    //
    // `ROW($l1, ..., $ln) < ROW($r1, ..., $rn)`
    // =>
    // `$l1 < $r1 OR ($l1 = $r1 AND (... ($ln < $rn)))`
    //
    // For `<=` and `>=` only the innermost comparison keeps the non-strict
    // operator; the outer ones use the strict `<` or `>`.
    //
    // For any other operator, one-element rows are unwrapped to a plain
    // binary operation and longer rows are left unchanged.
    if let Expr::Op {
        op,
        expr1: left,
        expr2: Some(right),
    } = expr
    {
        if let (Expr::Row { exprs: left }, Expr::Row { exprs: right }) =
            (&mut **left, &mut **right)
        {
            // Validate the row lengths up front for the comparison operators.
            if matches!(normalize::op(op)?, "=" | "<>" | "<" | "<=" | ">" | ">=") {
                if left.len() != right.len() {
                    sql_bail!("unequal number of entries in row expressions");
                }
                if left.is_empty() {
                    assert!(right.is_empty());
                    sql_bail!("cannot compare rows of zero length");
                }
            }
            match normalize::op(op)? {
                "=" | "<>" => {
                    let mut pairs = left.iter_mut().zip(right);
                    let mut new = pairs
                        .next()
                        .map(|(l, r)| l.take().equals(r.take()))
                        .expect("cannot compare rows of zero length");
                    // Each later pair is prepended to the conjunction.
                    for (l, r) in pairs {
                        new = l.take().equals(r.take()).and(new);
                    }
                    if normalize::op(op)? == "<>" {
                        new = new.negate();
                    }
                    *expr = new;
                }
                "<" | "<=" | ">" | ">=" => {
                    let strict_op = match normalize::op(op)? {
                        "<" | "<=" => "<",
                        ">" | ">=" => ">",
                        _ => unreachable!(),
                    };
                    // Start from the last pair, which keeps the original
                    // operator, and wrap it with the earlier pairs from
                    // right to left.
                    let (l, r) = (left.last_mut().unwrap(), right.last_mut().unwrap());
                    let mut new = l.take().binop(op.clone(), r.take());
                    for (l, r) in left.iter_mut().zip(right).rev().skip(1) {
                        new = l
                            .clone()
                            .binop(Op::bare(strict_op), r.clone())
                            .or(l.take().equals(r.take()).and(new));
                    }
                    *expr = new;
                }
                _ if left.len() == 1 && right.len() == 1 => {
                    let left = left.remove(0);
                    let right = right.remove(0);
                    *expr = left.binop(op.clone(), right);
                }
                _ => (),
            }
        }
    }

    // Recurse into the children of the (possibly rewritten) expression.
    visit_mut::visit_expr_mut(self, expr);
    Ok(())
}
859}