1use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25use mz_expr::func::variadic::{And, Or};
26pub use mz_expr::{
27 BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
28};
29use mz_ore::collections::CollectionExt;
30use mz_ore::error::ErrorExt;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::separated;
33use mz_ore::treat_as_equal::TreatAsEqual;
34use mz_ore::{soft_assert_or_log, stack};
35use mz_repr::adt::array::ArrayDimension;
36use mz_repr::adt::numeric::NumericMaxScale;
37use mz_repr::*;
38use serde::{Deserialize, Serialize};
39
40use crate::plan::error::PlanError;
41use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
42use crate::plan::typeconv::{self, CastContext, plan_cast};
43use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
44
45use super::plan_utils::GroupSizeHints;
46
/// Marker type that ties the HIR relation and scalar expression types together for the
/// generic `mz_expr::virtual_syntax` traits (`IR`, `AlgExcept`).
///
/// It carries no data and is used at the type level.
// NOTE(review): `Debug` is presumably omitted because the type is never instantiated;
// the allowance silences the `missing_debug_implementations` lint — confirm.
#[allow(missing_debug_implementations)]
pub struct Hir;
49
/// Binds the HIR expression types to the generic `IR` interface: relations are
/// `HirRelationExpr`s and scalars are `HirScalarExpr`s.
impl IR for Hir {
    type Relation = HirRelationExpr;
    type Scalar = HirScalarExpr;
}
54
55impl AlgExcept for Hir {
56 fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
57 if *all {
58 let rhs = rhs.negate();
59 HirRelationExpr::union(lhs, rhs).threshold()
60 } else {
61 let lhs = lhs.distinct();
62 let rhs = rhs.distinct().negate();
63 HirRelationExpr::union(lhs, rhs).threshold()
64 }
65 }
66
67 fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
68 let mut result = None;
69
70 use HirRelationExpr::*;
71 if let Threshold { input } = expr {
72 if let Union { base: lhs, inputs } = input.as_ref() {
73 if let [rhs] = &inputs[..] {
74 if let Negate { input: rhs } = rhs {
75 match (lhs.as_ref(), rhs.as_ref()) {
76 (Distinct { input: lhs }, Distinct { input: rhs }) => {
77 let all = false;
78 let lhs = lhs.as_ref();
79 let rhs = rhs.as_ref();
80 result = Some(Except { all, lhs, rhs })
81 }
82 (lhs, rhs) => {
83 let all = true;
84 result = Some(Except { all, lhs, rhs })
85 }
86 }
87 }
88 }
89 }
90 }
91
92 result
93 }
94}
95
/// A relational expression in the high-level intermediate representation (HIR).
///
/// This type derives `PartialOrd`/`Ord` and `Serialize`/`Deserialize`. The order of
/// its variants and fields therefore affects the derived ordering and may affect
/// serialized forms, so do not reorder them casually.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum HirRelationExpr {
    /// A literal collection consisting of `rows`, described by `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, whose type is `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// A group of bindings that are in scope in `body`. The `Rec` suggests the bindings
    /// may refer to each other — confirm.
    LetRec {
        /// Optional limit attached to the bindings (see `LetRecLimit`).
        limit: Option<LetRecLimit>,
        /// Each binding is (name, local id, value, relation type). The type is
        /// presumably the declared type of the binding — confirm.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// The expression in which the bindings are visible.
        body: Box<HirRelationExpr>,
    },
    /// Binds `value` to the local `id` (user-facing `name`) for use within `body`.
    Let {
        /// The user-facing name of the binding.
        name: String,
        /// The local identifier that `Get`s in `body` presumably use to refer to `value`.
        id: mz_expr::LocalId,
        /// The bound relation.
        value: Box<HirRelationExpr>,
        /// The expression in which the binding is visible.
        body: Box<HirRelationExpr>,
    },
    /// Projects `input` onto the column positions listed in `outputs`.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Extends `input` with the expressions in `scalars` (presumably one new column each).
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Calls the table function `func` with the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps the rows of `input` that satisfy `predicates` (presumably all of them).
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Joins `left` with `right` using join condition `on` and join `kind`.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Groups `input` by the `group_key` columns and computes `aggregates` per group.
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// Optional hint about the expected size of each group.
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Per group of `group_key` columns, orders rows by `order_key` and keeps the rows
    /// selected by `limit`/`offset`.
    TopK {
        /// The input relation.
        input: Box<HirRelationExpr>,
        /// Columns that define the groups.
        group_key: Vec<usize>,
        /// Ordering applied within each group.
        order_key: Vec<ColumnOrder>,
        /// Optional limit. It is a scalar expression here, not a plain integer.
        limit: Option<HirScalarExpr>,
        /// The offset, also a scalar expression.
        offset: HirScalarExpr,
        /// Optional hint about the expected size of each group.
        expected_group_size: Option<u64>,
    },
    /// Negates `input`. `AlgExcept::except` combines this with `Union` and `Threshold`
    /// to express `EXCEPT`.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Applies a threshold to `input`. `AlgExcept::except` uses it as the outermost
    /// operator of its `EXCEPT` encoding.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` and every relation in `inputs`. Keeping `base` separate
    /// guarantees that a union has at least one input.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
209
/// Optional name attached to every `HirScalarExpr` variant.
///
/// It is wrapped in `TreatAsEqual`, presumably so that the name does not influence the
/// derived equality/ordering/hashing of expressions that carry it — confirm against
/// `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
212
/// A scalar expression in the high-level intermediate representation (HIR).
///
/// Every variant carries a `NameMetadata`, an optional name for the expression. The
/// derived `Ord`/serde make variant and field order observable, so do not reorder them.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum HirScalarExpr {
    /// A reference to a column, possibly of an enclosing scope (see `ColumnRef::level`).
    Column(ColumnRef, NameMetadata),
    /// A reference to the query parameter with the given number.
    Parameter(usize, NameMetadata),
    /// A literal value, encoded as a `Row`, together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an unmaterializable function, which takes no argument expressions here.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a unary function on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a binary function on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a variadic function on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional: `then` if `cond` holds, otherwise `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// An existence test over a relation (presumably an `EXISTS` subquery).
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// A subquery used as a scalar (presumably `(SELECT ...)`).
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call.
    Windowing(WindowExpr, NameMetadata),
}
264
/// A window function call: the function together with its window partitioning and
/// ordering expressions.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct WindowExpr {
    /// The window function being computed.
    pub func: WindowExprType,
    /// Expressions that partition the input into windows.
    pub partition_by: Vec<HirScalarExpr>,
    /// Expressions that order rows within a window. The `ColumnOrder`s inside `func`
    /// presumably refer to these — confirm.
    pub order_by: Vec<HirScalarExpr>,
}
293
294impl WindowExpr {
295 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
296 where
297 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
298 {
299 #[allow(deprecated)]
300 self.func.visit_expressions(f)?;
301 for expr in self.partition_by.iter() {
302 f(expr)?;
303 }
304 for expr in self.order_by.iter() {
305 f(expr)?;
306 }
307 Ok(())
308 }
309
310 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
311 where
312 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
313 {
314 #[allow(deprecated)]
315 self.func.visit_expressions_mut(f)?;
316 for expr in self.partition_by.iter_mut() {
317 f(expr)?;
318 }
319 for expr in self.order_by.iter_mut() {
320 f(expr)?;
321 }
322 Ok(())
323 }
324}
325
326impl VisitChildren<HirScalarExpr> for WindowExpr {
327 fn visit_children<F>(&self, mut f: F)
328 where
329 F: FnMut(&HirScalarExpr),
330 {
331 self.func.visit_children(&mut f);
332 for expr in self.partition_by.iter() {
333 f(expr);
334 }
335 for expr in self.order_by.iter() {
336 f(expr);
337 }
338 }
339
340 fn visit_mut_children<F>(&mut self, mut f: F)
341 where
342 F: FnMut(&mut HirScalarExpr),
343 {
344 self.func.visit_mut_children(&mut f);
345 for expr in self.partition_by.iter_mut() {
346 f(expr);
347 }
348 for expr in self.order_by.iter_mut() {
349 f(expr);
350 }
351 }
352
353 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
354 where
355 F: FnMut(&HirScalarExpr) -> Result<(), E>,
356 E: From<RecursionLimitError>,
357 {
358 self.func.try_visit_children(&mut f)?;
359 for expr in self.partition_by.iter() {
360 f(expr)?;
361 }
362 for expr in self.order_by.iter() {
363 f(expr)?;
364 }
365 Ok(())
366 }
367
368 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
369 where
370 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
371 E: From<RecursionLimitError>,
372 {
373 self.func.try_visit_mut_children(&mut f)?;
374 for expr in self.partition_by.iter_mut() {
375 f(expr)?;
376 }
377 for expr in self.order_by.iter_mut() {
378 f(expr)?;
379 }
380 Ok(())
381 }
382}
383
/// The function part of a `WindowExpr`, grouped by the kind of window function:
///
/// * `Scalar`: functions without an argument expression (`row_number`, `rank`,
///   `dense_rank`; see `ScalarWindowFunc`).
/// * `Value`: functions computed from an argument expression (`lag`, `lead`,
///   `first_value`, `last_value`, or a fusion of these; see `ValueWindowFunc`).
/// * `Aggregate`: an aggregate function evaluated over the window.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum WindowExprType {
    Scalar(ScalarWindowExpr),
    Value(ValueWindowExpr),
    Aggregate(AggregateWindowExpr),
}
416
417impl WindowExprType {
418 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
419 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
420 where
421 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
422 {
423 #[allow(deprecated)]
424 match self {
425 Self::Scalar(expr) => expr.visit_expressions(f),
426 Self::Value(expr) => expr.visit_expressions(f),
427 Self::Aggregate(expr) => expr.visit_expressions(f),
428 }
429 }
430
431 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
432 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
433 where
434 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
435 {
436 #[allow(deprecated)]
437 match self {
438 Self::Scalar(expr) => expr.visit_expressions_mut(f),
439 Self::Value(expr) => expr.visit_expressions_mut(f),
440 Self::Aggregate(expr) => expr.visit_expressions_mut(f),
441 }
442 }
443
444 fn typ(
445 &self,
446 outers: &[SqlRelationType],
447 inner: &SqlRelationType,
448 params: &BTreeMap<usize, SqlScalarType>,
449 ) -> SqlColumnType {
450 match self {
451 Self::Scalar(expr) => expr.typ(outers, inner, params),
452 Self::Value(expr) => expr.typ(outers, inner, params),
453 Self::Aggregate(expr) => expr.typ(outers, inner, params),
454 }
455 }
456}
457
458impl VisitChildren<HirScalarExpr> for WindowExprType {
459 fn visit_children<F>(&self, f: F)
460 where
461 F: FnMut(&HirScalarExpr),
462 {
463 match self {
464 Self::Scalar(_) => (),
465 Self::Value(expr) => expr.visit_children(f),
466 Self::Aggregate(expr) => expr.visit_children(f),
467 }
468 }
469
470 fn visit_mut_children<F>(&mut self, f: F)
471 where
472 F: FnMut(&mut HirScalarExpr),
473 {
474 match self {
475 Self::Scalar(_) => (),
476 Self::Value(expr) => expr.visit_mut_children(f),
477 Self::Aggregate(expr) => expr.visit_mut_children(f),
478 }
479 }
480
481 fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
482 where
483 F: FnMut(&HirScalarExpr) -> Result<(), E>,
484 E: From<RecursionLimitError>,
485 {
486 match self {
487 Self::Scalar(_) => Ok(()),
488 Self::Value(expr) => expr.try_visit_children(f),
489 Self::Aggregate(expr) => expr.try_visit_children(f),
490 }
491 }
492
493 fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
494 where
495 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
496 E: From<RecursionLimitError>,
497 {
498 match self {
499 Self::Scalar(_) => Ok(()),
500 Self::Value(expr) => expr.try_visit_mut_children(f),
501 Self::Aggregate(expr) => expr.try_visit_mut_children(f),
502 }
503 }
504}
505
/// A window function without an argument expression (see `ScalarWindowFunc`).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct ScalarWindowExpr {
    /// Which function to compute.
    pub func: ScalarWindowFunc,
    /// Ordering passed through to the lowered `mz_expr::AggregateFunc` by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
}
521
522impl ScalarWindowExpr {
523 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
524 pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
525 where
526 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
527 {
528 match self.func {
529 ScalarWindowFunc::RowNumber => {}
530 ScalarWindowFunc::Rank => {}
531 ScalarWindowFunc::DenseRank => {}
532 }
533 Ok(())
534 }
535
536 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
537 pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
538 where
539 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
540 {
541 match self.func {
542 ScalarWindowFunc::RowNumber => {}
543 ScalarWindowFunc::Rank => {}
544 ScalarWindowFunc::DenseRank => {}
545 }
546 Ok(())
547 }
548
549 fn typ(
550 &self,
551 _outers: &[SqlRelationType],
552 _inner: &SqlRelationType,
553 _params: &BTreeMap<usize, SqlScalarType>,
554 ) -> SqlColumnType {
555 self.func.output_sql_type()
556 }
557
558 pub fn into_expr(self) -> mz_expr::AggregateFunc {
559 match self.func {
560 ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
561 order_by: self.order_by,
562 },
563 ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
564 order_by: self.order_by,
565 },
566 ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
567 order_by: self.order_by,
568 },
569 }
570 }
571}
572
/// Window functions that take no argument expression. All of them produce a
/// non-nullable `Int64` (see `output_sql_type`).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum ScalarWindowFunc {
    /// `row_number`
    RowNumber,
    /// `rank`
    Rank,
    /// `dense_rank`
    DenseRank,
}
590
591impl Display for ScalarWindowFunc {
592 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
593 match self {
594 ScalarWindowFunc::RowNumber => write!(f, "row_number"),
595 ScalarWindowFunc::Rank => write!(f, "rank"),
596 ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
597 }
598 }
599}
600
601impl ScalarWindowFunc {
602 pub fn output_sql_type(&self) -> SqlColumnType {
603 match self {
604 ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
605 ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
606 ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
607 }
608 }
609}
610
/// A window function computed from an argument expression (see `ValueWindowFunc`).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct ValueWindowExpr {
    /// Which function to compute.
    pub func: ValueWindowFunc,
    /// The argument expression.
    ///
    /// `ValueWindowFunc::output_sql_type` expects it to be a record for `Lag`/`Lead`
    /// (the result type is the first field's type). For `Fused` it expects one record
    /// field per fused function.
    pub args: Box<HirScalarExpr>,
    /// Ordering passed through to the lowered function by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
    /// Window frame. When lowering it is used by `FirstValue`/`LastValue` (and passed
    /// to each fused function); `Lag`/`Lead` do not use it.
    pub window_frame: WindowFrame,
    /// `IGNORE NULLS` flag. When lowering it is used by `Lag`/`Lead` (and passed to each
    /// fused function).
    pub ignore_nulls: bool,
}
635
636impl Display for ValueWindowFunc {
637 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
638 match self {
639 ValueWindowFunc::Lag => write!(f, "lag"),
640 ValueWindowFunc::Lead => write!(f, "lead"),
641 ValueWindowFunc::FirstValue => write!(f, "first_value"),
642 ValueWindowFunc::LastValue => write!(f, "last_value"),
643 ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
644 }
645 }
646}
647
648impl ValueWindowExpr {
649 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
650 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
651 where
652 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
653 {
654 f(&self.args)
655 }
656
657 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
658 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
659 where
660 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
661 {
662 f(&mut self.args)
663 }
664
665 fn typ(
666 &self,
667 outers: &[SqlRelationType],
668 inner: &SqlRelationType,
669 params: &BTreeMap<usize, SqlScalarType>,
670 ) -> SqlColumnType {
671 self.func
672 .output_sql_type(self.args.typ(outers, inner, params))
673 }
674
675 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
677 (
678 self.args,
679 self.func
680 .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
681 )
682 }
683}
684
685impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
686 fn visit_children<F>(&self, mut f: F)
687 where
688 F: FnMut(&HirScalarExpr),
689 {
690 f(&self.args)
691 }
692
693 fn visit_mut_children<F>(&mut self, mut f: F)
694 where
695 F: FnMut(&mut HirScalarExpr),
696 {
697 f(&mut self.args)
698 }
699
700 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
701 where
702 F: FnMut(&HirScalarExpr) -> Result<(), E>,
703 E: From<RecursionLimitError>,
704 {
705 f(&self.args)
706 }
707
708 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
709 where
710 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
711 E: From<RecursionLimitError>,
712 {
713 f(&mut self.args)
714 }
715}
716
/// Window functions computed from an argument expression.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum ValueWindowFunc {
    /// `lag`
    Lag,
    /// `lead`
    Lead,
    /// `first_value`
    FirstValue,
    /// `last_value`
    LastValue,
    /// Several value window functions computed together over one shared argument
    /// record (one field per function). The result is a record of their outputs.
    Fused(Vec<ValueWindowFunc>),
}
736
737impl ValueWindowFunc {
738 pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
739 match self {
740 ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
741 input_type.scalar_type.unwrap_record_element_type()[0]
743 .clone()
744 .nullable(true)
745 }
746 ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
747 input_type.scalar_type.nullable(true)
748 }
749 ValueWindowFunc::Fused(funcs) => {
750 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
751 SqlScalarType::Record {
752 fields: funcs
753 .iter()
754 .zip_eq(input_types)
755 .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
756 .collect(),
757 custom_id: None,
758 }
759 .nullable(false)
760 }
761 }
762 }
763
764 pub fn into_expr(
765 self,
766 order_by: Vec<ColumnOrder>,
767 window_frame: WindowFrame,
768 ignore_nulls: bool,
769 ) -> mz_expr::AggregateFunc {
770 match self {
771 ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
773 order_by,
774 lag_lead: mz_expr::LagLeadType::Lag,
775 ignore_nulls,
776 },
777 ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
778 order_by,
779 lag_lead: mz_expr::LagLeadType::Lead,
780 ignore_nulls,
781 },
782 ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
783 order_by,
784 window_frame,
785 },
786 ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
787 order_by,
788 window_frame,
789 },
790 ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
791 funcs: funcs
792 .into_iter()
793 .map(|func| {
794 func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
795 })
796 .collect(),
797 order_by,
798 },
799 }
800 }
801}
802
/// An aggregate function evaluated as a window function.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateWindowExpr {
    /// The aggregate and its argument. `func` may be `AggregateFunc::FusedWindowAgg`.
    pub aggregate_expr: AggregateExpr,
    /// Ordering passed through to the lowered window aggregate.
    pub order_by: Vec<ColumnOrder>,
    /// Window frame passed through to the lowered window aggregate.
    pub window_frame: WindowFrame,
}
819
820impl AggregateWindowExpr {
821 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
822 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
823 where
824 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
825 {
826 f(&self.aggregate_expr.expr)
827 }
828
829 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
830 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
831 where
832 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
833 {
834 f(&mut self.aggregate_expr.expr)
835 }
836
837 fn typ(
838 &self,
839 outers: &[SqlRelationType],
840 inner: &SqlRelationType,
841 params: &BTreeMap<usize, SqlScalarType>,
842 ) -> SqlColumnType {
843 self.aggregate_expr
844 .func
845 .output_sql_type(self.aggregate_expr.expr.typ(outers, inner, params))
846 }
847
848 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
849 if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
850 (
851 self.aggregate_expr.expr,
852 FusedWindowAggregate {
853 wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
854 order_by: self.order_by,
855 window_frame: self.window_frame,
856 },
857 )
858 } else {
859 (
860 self.aggregate_expr.expr,
861 WindowAggregate {
862 wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
863 order_by: self.order_by,
864 window_frame: self.window_frame,
865 },
866 )
867 }
868 }
869}
870
871impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
872 fn visit_children<F>(&self, mut f: F)
873 where
874 F: FnMut(&HirScalarExpr),
875 {
876 f(&self.aggregate_expr.expr)
877 }
878
879 fn visit_mut_children<F>(&mut self, mut f: F)
880 where
881 F: FnMut(&mut HirScalarExpr),
882 {
883 f(&mut self.aggregate_expr.expr)
884 }
885
886 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
887 where
888 F: FnMut(&HirScalarExpr) -> Result<(), E>,
889 E: From<RecursionLimitError>,
890 {
891 f(&self.aggregate_expr.expr)
892 }
893
894 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
895 where
896 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
897 E: From<RecursionLimitError>,
898 {
899 f(&mut self.aggregate_expr.expr)
900 }
901}
902
/// A scalar expression whose type may not be known yet.
///
/// `Coerced` wraps an already-typed HIR expression. Parameters, `NULL` literals and
/// string literals have no concrete type until they are coerced (their
/// `AbstractExpr::typ` is `CoercibleColumnType::Uncoerced`). A record literal's type is
/// built from the types of its fields.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    Coerced(HirScalarExpr),
    Parameter(usize),
    LiteralNull,
    LiteralString(String),
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
935
936impl CoercibleScalarExpr {
937 pub fn type_as(
938 self,
939 ecx: &ExprContext,
940 ty: &SqlScalarType,
941 ) -> Result<HirScalarExpr, PlanError> {
942 let expr = typeconv::plan_coerce(ecx, self, ty)?;
943 let expr_ty = ecx.scalar_type(&expr);
944 if ty != &expr_ty {
945 sql_bail!(
946 "{} must have type {}, not type {}",
947 ecx.name,
948 ecx.humanize_sql_scalar_type(ty, false),
949 ecx.humanize_sql_scalar_type(&expr_ty, false),
950 );
951 }
952 Ok(expr)
953 }
954
955 pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
956 typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
957 }
958
959 pub fn cast_to(
960 self,
961 ecx: &ExprContext,
962 ccx: CastContext,
963 ty: &SqlScalarType,
964 ) -> Result<HirScalarExpr, PlanError> {
965 let expr = typeconv::plan_coerce(ecx, self, ty)?;
966 typeconv::plan_cast(ecx, ccx, expr, ty)
967 }
968}
969
/// The column type of a `CoercibleScalarExpr`. It may be concrete, a record of
/// possibly-uncoerced fields, or not yet known.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    Coerced(SqlColumnType),
    Record(Vec<CoercibleColumnType>),
    Uncoerced,
}
977
978impl CoercibleColumnType {
979 pub fn nullable(&self) -> bool {
981 match self {
982 CoercibleColumnType::Coerced(ct) => ct.nullable,
984
985 CoercibleColumnType::Record(_) => false,
987
988 CoercibleColumnType::Uncoerced => true,
991 }
992 }
993}
994
/// The scalar type of a `CoercibleScalarExpr`. It may be concrete, a record of
/// possibly-uncoerced fields, or not yet known.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    Coerced(SqlScalarType),
    Record(Vec<CoercibleColumnType>),
    Uncoerced,
}
1002
1003impl CoercibleScalarType {
1004 pub fn is_coerced(&self) -> bool {
1006 matches!(self, CoercibleScalarType::Coerced(_))
1007 }
1008
1009 pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1011 match self {
1012 CoercibleScalarType::Coerced(t) => Some(t),
1013 _ => None,
1014 }
1015 }
1016
1017 pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1020 where
1021 F: FnOnce(SqlScalarType) -> SqlScalarType,
1022 {
1023 match self {
1024 CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1025 _ => self,
1026 }
1027 }
1028
1029 pub fn force_coerced_if_record(&mut self) {
1036 fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1037 let mut fields = vec![];
1038 for (i, uf) in uncoerced_fields.enumerate() {
1039 let name = ColumnName::from(format!("f{}", i + 1));
1040 let ty = match uf {
1041 CoercibleColumnType::Coerced(ty) => ty,
1042 CoercibleColumnType::Record(mut fields) => {
1043 convert(fields.drain(..)).nullable(false)
1044 }
1045 CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1046 };
1047 fields.push((name, ty))
1048 }
1049 SqlScalarType::Record {
1050 fields: fields.into(),
1051 custom_id: None,
1052 }
1053 }
1054
1055 if let CoercibleScalarType::Record(fields) = self {
1056 *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1057 }
1058 }
1059}
1060
/// An expression whose type can be computed, where the kind of type depends on the
/// implementor. For example, `CoercibleScalarExpr` reports a `CoercibleColumnType`,
/// which may still be uncoerced.
pub trait AbstractExpr {
    /// The (possibly abstract) column type produced by `typ`.
    type Type: AbstractColumnType;

    /// Computes the type of this expression.
    ///
    /// Going by the parameter names, `outers` are presumably the relation types of
    /// enclosing scopes, `inner` the type of the current relation, and `params` the
    /// known parameter types — confirm.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
1075
1076impl AbstractExpr for CoercibleScalarExpr {
1077 type Type = CoercibleColumnType;
1078
1079 fn typ(
1080 &self,
1081 outers: &[SqlRelationType],
1082 inner: &SqlRelationType,
1083 params: &BTreeMap<usize, SqlScalarType>,
1084 ) -> Self::Type {
1085 match self {
1086 CoercibleScalarExpr::Coerced(expr) => {
1087 CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1088 }
1089 CoercibleScalarExpr::LiteralRecord(scalars) => {
1090 let fields = scalars
1091 .iter()
1092 .map(|s| s.typ(outers, inner, params))
1093 .collect();
1094 CoercibleColumnType::Record(fields)
1095 }
1096 _ => CoercibleColumnType::Uncoerced,
1097 }
1098 }
1099}
1100
/// A column type from which a scalar type can be extracted. For concrete types this is
/// `SqlScalarType`. For coercible types it is `CoercibleScalarType`.
pub trait AbstractColumnType {
    /// The scalar type corresponding to this column type.
    type AbstractScalarType;

    /// Converts this column type into its scalar type, consuming it.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1112
1113impl AbstractColumnType for SqlColumnType {
1114 type AbstractScalarType = SqlScalarType;
1115
1116 fn scalar_type(self) -> Self::AbstractScalarType {
1117 self.scalar_type
1118 }
1119}
1120
1121impl AbstractColumnType for CoercibleColumnType {
1122 type AbstractScalarType = CoercibleScalarType;
1123
1124 fn scalar_type(self) -> Self::AbstractScalarType {
1125 match self {
1126 CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1127 CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1128 CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1129 }
1130 }
1131}
1132
1133impl From<HirScalarExpr> for CoercibleScalarExpr {
1134 fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1135 CoercibleScalarExpr::Coerced(expr)
1136 }
1137}
1138
/// A reference to a column, possibly in an enclosing scope.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize
)]
pub struct ColumnRef {
    /// Which scope the column belongs to. This is presumably the number of scopes
    /// outward from the current one, with 0 meaning the innermost — confirm.
    pub level: usize,
    /// The column's position within the relation of that scope.
    pub column: usize,
}
1171
/// The kind of a `HirRelationExpr::Join`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum JoinKind {
    Inner,
    LeftOuter,
    RightOuter,
    FullOuter,
}
1189
1190impl fmt::Display for JoinKind {
1191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1192 write!(
1193 f,
1194 "{}",
1195 match self {
1196 JoinKind::Inner => "Inner",
1197 JoinKind::LeftOuter => "LeftOuter",
1198 JoinKind::RightOuter => "RightOuter",
1199 JoinKind::FullOuter => "FullOuter",
1200 }
1201 )
1202 }
1203}
1204
1205impl JoinKind {
1206 pub fn can_be_correlated(&self) -> bool {
1207 match self {
1208 JoinKind::Inner | JoinKind::LeftOuter => true,
1209 JoinKind::RightOuter | JoinKind::FullOuter => false,
1210 }
1211 }
1212
1213 pub fn can_elide_identity_left_join(&self) -> bool {
1214 match self {
1215 JoinKind::Inner | JoinKind::RightOuter => true,
1216 JoinKind::LeftOuter | JoinKind::FullOuter => false,
1217 }
1218 }
1219
1220 pub fn can_elide_identity_right_join(&self) -> bool {
1221 match self {
1222 JoinKind::Inner | JoinKind::LeftOuter => true,
1223 JoinKind::RightOuter | JoinKind::FullOuter => false,
1224 }
1225 }
1226}
1227
/// An aggregate function applied to a single argument expression.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The single argument expression. Aggregates that need several inputs presumably
    /// pack them into a record — confirm.
    pub expr: Box<HirScalarExpr>,
    /// Whether the aggregate was marked `DISTINCT`.
    pub distinct: bool,
}
1244
/// Aggregate functions at the HIR level.
///
/// `into_expr` maps every variant except `FusedWindowAgg` one-to-one onto
/// `mz_expr::AggregateFunc`. The derived `Ord`/serde make variant order observable.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum AggregateFunc {
    // `max` over the type named by the suffix.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // `min` over the type named by the suffix.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // `sum` over the type named by the suffix.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// `count`
    Count,
    /// Boolean "any" (identity datum `false`, see `identity_datum`).
    Any,
    /// Boolean "all" (identity datum `true`, see `identity_datum`).
    All,
    /// `jsonb_agg`, with the given input ordering.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// `jsonb_object_agg`, with the given input ordering.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Map aggregation with the given input ordering and map value type.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Array concatenation with the given input ordering (identity: empty array).
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// List concatenation with the given input ordering (identity: empty list).
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// `string_agg`, with the given input ordering.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Several aggregates evaluated together as one window aggregate.
    ///
    /// `AggregateWindowExpr::into_expr` lowers this variant to `FusedWindowAggregate`.
    /// `AggregateFunc::into_expr` and `identity_datum` panic on it.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Placeholder aggregate. It lowers to `mz_expr::AggregateFunc::Dummy`, and its
    /// identity datum is `Datum::Dummy`.
    Dummy,
}
1358
1359impl AggregateFunc {
/// Converts this HIR aggregate function into the corresponding `mz_expr::AggregateFunc`.
///
/// The mapping is one-to-one, variant for variant, with any `order_by`/`value_type`
/// payload moved across unchanged.
///
/// # Panics
///
/// Panics on `FusedWindowAgg`. Fused window aggregates are lowered by
/// `AggregateWindowExpr::into_expr`, which calls this method on each fused function
/// instead.
pub fn into_expr(self) -> mz_expr::AggregateFunc {
    match self {
        AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
        AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
        AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
        AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
        AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
        AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
        AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
        AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
        AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
        AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
        AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
        AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
        AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
        AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
        AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
        AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
        AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
        AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
        AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
        AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
        AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
        AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
        AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
        AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
        AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
        AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
        AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
        AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
        AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
        AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
        AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
        AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
        AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
        AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
        AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
        AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
        AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
        AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
        AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
        AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
        AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
        AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
        AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
        AggregateFunc::Count => mz_expr::AggregateFunc::Count,
        AggregateFunc::Any => mz_expr::AggregateFunc::Any,
        AggregateFunc::All => mz_expr::AggregateFunc::All,
        AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
        AggregateFunc::JsonbObjectAgg { order_by } => {
            mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
        }
        AggregateFunc::MapAgg {
            order_by,
            value_type,
        } => mz_expr::AggregateFunc::MapAgg {
            order_by,
            value_type,
        },
        AggregateFunc::ArrayConcat { order_by } => {
            mz_expr::AggregateFunc::ArrayConcat { order_by }
        }
        AggregateFunc::ListConcat { order_by } => {
            mz_expr::AggregateFunc::ListConcat { order_by }
        }
        AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
        AggregateFunc::FusedWindowAgg { funcs: _ } => {
            // Fused window aggregates have no standalone lowering here;
            // `AggregateWindowExpr::into_expr` handles them before reaching this point.
            panic!("into_expr called on FusedWindowAgg")
        }
        AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
    }
}
1435
/// Returns the datum associated with this aggregate's identity element.
///
/// The special cases are `false` for `Any`, `true` for `All`, `Datum::Dummy` for
/// `Dummy`, an empty array for `ArrayConcat` and an empty list for `ListConcat`. Every
/// other non-fused aggregate reports `Datum::Null`. Presumably this is a value whose
/// inclusion in the aggregated input leaves the result unchanged — confirm against
/// callers.
///
/// # Panics
///
/// Panics for `FusedWindowAgg`, which has no identity datum.
pub fn identity_datum(&self) -> Datum<'static> {
    match self {
        AggregateFunc::Any => Datum::False,
        AggregateFunc::All => Datum::True,
        AggregateFunc::Dummy => Datum::Dummy,
        AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
        AggregateFunc::ListConcat { .. } => Datum::empty_list(),
        AggregateFunc::MaxNumeric
        | AggregateFunc::MaxInt16
        | AggregateFunc::MaxInt32
        | AggregateFunc::MaxInt64
        | AggregateFunc::MaxUInt16
        | AggregateFunc::MaxUInt32
        | AggregateFunc::MaxUInt64
        | AggregateFunc::MaxMzTimestamp
        | AggregateFunc::MaxFloat32
        | AggregateFunc::MaxFloat64
        | AggregateFunc::MaxBool
        | AggregateFunc::MaxString
        | AggregateFunc::MaxDate
        | AggregateFunc::MaxTimestamp
        | AggregateFunc::MaxTimestampTz
        | AggregateFunc::MaxInterval
        | AggregateFunc::MaxTime
        | AggregateFunc::MinNumeric
        | AggregateFunc::MinInt16
        | AggregateFunc::MinInt32
        | AggregateFunc::MinInt64
        | AggregateFunc::MinUInt16
        | AggregateFunc::MinUInt32
        | AggregateFunc::MinUInt64
        | AggregateFunc::MinMzTimestamp
        | AggregateFunc::MinFloat32
        | AggregateFunc::MinFloat64
        | AggregateFunc::MinBool
        | AggregateFunc::MinString
        | AggregateFunc::MinDate
        | AggregateFunc::MinTimestamp
        | AggregateFunc::MinTimestampTz
        | AggregateFunc::MinInterval
        | AggregateFunc::MinTime
        | AggregateFunc::SumInt16
        | AggregateFunc::SumInt32
        | AggregateFunc::SumInt64
        | AggregateFunc::SumUInt16
        | AggregateFunc::SumUInt32
        | AggregateFunc::SumUInt64
        | AggregateFunc::SumFloat32
        | AggregateFunc::SumFloat64
        | AggregateFunc::SumNumeric
        | AggregateFunc::Count
        | AggregateFunc::JsonbAgg { .. }
        | AggregateFunc::JsonbObjectAgg { .. }
        | AggregateFunc::MapAgg { .. }
        | AggregateFunc::StringAgg { .. } => Datum::Null,
        AggregateFunc::FusedWindowAgg { funcs: _ } => {
            // A fused aggregate wraps several functions, so there is no single identity.
            panic!("FusedWindowAgg doesn't have an identity_datum")
        }
    }
}
1511
1512 pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1518 let scalar_type = match self {
1519 AggregateFunc::Count => SqlScalarType::Int64,
1520 AggregateFunc::Any => SqlScalarType::Bool,
1521 AggregateFunc::All => SqlScalarType::Bool,
1522 AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1523 AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1524 AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1525 AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1526 AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1527 max_scale: Some(NumericMaxScale::ZERO),
1528 },
1529 AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1530 AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1531 max_scale: Some(NumericMaxScale::ZERO),
1532 },
1533 AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1534 value_type: Box::new(value_type.clone()),
1535 custom_id: None,
1536 },
1537 AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1538 match input_type.scalar_type {
1539 SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1541 _ => unreachable!(),
1542 }
1543 }
1544 AggregateFunc::MaxNumeric
1545 | AggregateFunc::MaxInt16
1546 | AggregateFunc::MaxInt32
1547 | AggregateFunc::MaxInt64
1548 | AggregateFunc::MaxUInt16
1549 | AggregateFunc::MaxUInt32
1550 | AggregateFunc::MaxUInt64
1551 | AggregateFunc::MaxMzTimestamp
1552 | AggregateFunc::MaxFloat32
1553 | AggregateFunc::MaxFloat64
1554 | AggregateFunc::MaxBool
1555 | AggregateFunc::MaxString
1556 | AggregateFunc::MaxDate
1557 | AggregateFunc::MaxTimestamp
1558 | AggregateFunc::MaxTimestampTz
1559 | AggregateFunc::MaxInterval
1560 | AggregateFunc::MaxTime
1561 | AggregateFunc::MinNumeric
1562 | AggregateFunc::MinInt16
1563 | AggregateFunc::MinInt32
1564 | AggregateFunc::MinInt64
1565 | AggregateFunc::MinUInt16
1566 | AggregateFunc::MinUInt32
1567 | AggregateFunc::MinUInt64
1568 | AggregateFunc::MinMzTimestamp
1569 | AggregateFunc::MinFloat32
1570 | AggregateFunc::MinFloat64
1571 | AggregateFunc::MinBool
1572 | AggregateFunc::MinString
1573 | AggregateFunc::MinDate
1574 | AggregateFunc::MinTimestamp
1575 | AggregateFunc::MinTimestampTz
1576 | AggregateFunc::MinInterval
1577 | AggregateFunc::MinTime
1578 | AggregateFunc::SumFloat32
1579 | AggregateFunc::SumFloat64
1580 | AggregateFunc::SumNumeric
1581 | AggregateFunc::Dummy => input_type.scalar_type,
1582 AggregateFunc::FusedWindowAgg { funcs } => {
1583 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1584 SqlScalarType::Record {
1585 fields: funcs
1586 .iter()
1587 .zip_eq(input_types)
1588 .map(|(f, t)| (ColumnName::from(""), f.output_sql_type(t.clone())))
1589 .collect(),
1590 custom_id: None,
1591 }
1592 }
1593 };
1594 let nullable = !matches!(self, AggregateFunc::Count);
1596 scalar_type.nullable(nullable)
1597 }
1598
1599 pub fn is_order_sensitive(&self) -> bool {
1600 use AggregateFunc::*;
1601 matches!(
1602 self,
1603 JsonbAgg { .. }
1604 | JsonbObjectAgg { .. }
1605 | MapAgg { .. }
1606 | ArrayConcat { .. }
1607 | ListConcat { .. }
1608 | StringAgg { .. }
1609 )
1610 }
1611}
1612
1613impl HirRelationExpr {
1614 pub fn top_level_typ(&self) -> SqlRelationType {
1616 self.typ(&[], &BTreeMap::new())
1617 }
1618
1619 pub fn typ(
1624 &self,
1625 outers: &[SqlRelationType],
1626 params: &BTreeMap<usize, SqlScalarType>,
1627 ) -> SqlRelationType {
1628 stack::maybe_grow(|| match self {
1629 HirRelationExpr::Constant { typ, .. } => typ.clone(),
1630 HirRelationExpr::Get { typ, .. } => typ.clone(),
1631 HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1632 HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1633 HirRelationExpr::Project { input, outputs } => {
1634 let input_typ = input.typ(outers, params);
1635 SqlRelationType::new(
1636 outputs
1637 .iter()
1638 .map(|&i| input_typ.column_types[i].clone())
1639 .collect(),
1640 )
1641 }
1642 HirRelationExpr::Map { input, scalars } => {
1643 let mut typ = input.typ(outers, params);
1644 for scalar in scalars {
1645 typ.column_types.push(scalar.typ(outers, &typ, params));
1646 }
1647 typ
1648 }
1649 HirRelationExpr::CallTable { func, exprs: _ } => func.output_sql_type(),
1650 HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1651 input.typ(outers, params)
1652 }
1653 HirRelationExpr::Join {
1654 left, right, kind, ..
1655 } => {
1656 let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1657 let right_nullable =
1658 matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1659 let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1660 let nullable = t.nullable || left_nullable;
1661 t.nullable(nullable)
1662 });
1663 let mut outers = outers.to_vec();
1664 outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1665 let rt = right
1666 .typ(&outers, params)
1667 .column_types
1668 .into_iter()
1669 .map(|t| {
1670 let nullable = t.nullable || right_nullable;
1671 t.nullable(nullable)
1672 });
1673 SqlRelationType::new(lt.chain(rt).collect())
1674 }
1675 HirRelationExpr::Reduce {
1676 input,
1677 group_key,
1678 aggregates,
1679 expected_group_size: _,
1680 } => {
1681 let input_typ = input.typ(outers, params);
1682 let mut column_types = group_key
1683 .iter()
1684 .map(|&i| input_typ.column_types[i].clone())
1685 .collect::<Vec<_>>();
1686 for agg in aggregates {
1687 column_types.push(agg.typ(outers, &input_typ, params));
1688 }
1689 SqlRelationType::new(column_types)
1691 }
1692 HirRelationExpr::Distinct { input }
1694 | HirRelationExpr::Negate { input }
1695 | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1696 HirRelationExpr::Union { base, inputs } => {
1697 let mut base_cols = base.typ(outers, params).column_types;
1698 for input in inputs {
1699 for (base_col, col) in base_cols
1700 .iter_mut()
1701 .zip_eq(input.typ(outers, params).column_types)
1702 {
1703 *base_col = base_col.sql_union(&col).unwrap(); }
1705 }
1706 SqlRelationType::new(base_cols)
1707 }
1708 })
1709 }
1710
1711 pub fn arity(&self) -> usize {
1712 match self {
1713 HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1714 HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1715 HirRelationExpr::Let { body, .. } => body.arity(),
1716 HirRelationExpr::LetRec { body, .. } => body.arity(),
1717 HirRelationExpr::Project { outputs, .. } => outputs.len(),
1718 HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1719 HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1720 HirRelationExpr::Filter { input, .. }
1721 | HirRelationExpr::TopK { input, .. }
1722 | HirRelationExpr::Distinct { input }
1723 | HirRelationExpr::Negate { input }
1724 | HirRelationExpr::Threshold { input } => input.arity(),
1725 HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1726 HirRelationExpr::Union { base, .. } => base.arity(),
1727 HirRelationExpr::Reduce {
1728 group_key,
1729 aggregates,
1730 ..
1731 } => group_key.len() + aggregates.len(),
1732 }
1733 }
1734
1735 pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1737 match self {
1738 Self::Constant { rows, typ } => Some((rows, typ)),
1739 _ => None,
1740 }
1741 }
1742
1743 pub fn is_correlated(&self) -> bool {
1746 let mut correlated = false;
1747 #[allow(deprecated)]
1748 self.visit_columns(0, &mut |depth, col| {
1749 if col.level > depth && col.level - depth == 1 {
1750 correlated = true;
1751 }
1752 });
1753 correlated
1754 }
1755
1756 pub fn is_join_identity(&self) -> bool {
1757 match self {
1758 HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1759 _ => false,
1760 }
1761 }
1762
1763 pub fn project(self, outputs: Vec<usize>) -> Self {
1764 if outputs.iter().copied().eq(0..self.arity()) {
1765 self
1767 } else {
1768 HirRelationExpr::Project {
1769 input: Box::new(self),
1770 outputs,
1771 }
1772 }
1773 }
1774
1775 pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1776 if scalars.is_empty() {
1777 self
1779 } else if let HirRelationExpr::Map {
1780 scalars: old_scalars,
1781 input: _,
1782 } = &mut self
1783 {
1784 old_scalars.extend(scalars);
1786 self
1787 } else {
1788 HirRelationExpr::Map {
1789 input: Box::new(self),
1790 scalars,
1791 }
1792 }
1793 }
1794
1795 pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1796 if let HirRelationExpr::Filter {
1797 input: _,
1798 predicates,
1799 } = &mut self
1800 {
1801 predicates.extend(preds);
1802 predicates.sort();
1803 predicates.dedup();
1804 self
1805 } else {
1806 preds.sort();
1807 preds.dedup();
1808 HirRelationExpr::Filter {
1809 input: Box::new(self),
1810 predicates: preds,
1811 }
1812 }
1813 }
1814
1815 pub fn reduce(
1816 self,
1817 group_key: Vec<usize>,
1818 aggregates: Vec<AggregateExpr>,
1819 expected_group_size: Option<u64>,
1820 ) -> Self {
1821 HirRelationExpr::Reduce {
1822 input: Box::new(self),
1823 group_key,
1824 aggregates,
1825 expected_group_size,
1826 }
1827 }
1828
1829 pub fn top_k(
1830 self,
1831 group_key: Vec<usize>,
1832 order_key: Vec<ColumnOrder>,
1833 limit: Option<HirScalarExpr>,
1834 offset: HirScalarExpr,
1835 expected_group_size: Option<u64>,
1836 ) -> Self {
1837 HirRelationExpr::TopK {
1838 input: Box::new(self),
1839 group_key,
1840 order_key,
1841 limit,
1842 offset,
1843 expected_group_size,
1844 }
1845 }
1846
1847 pub fn negate(self) -> Self {
1848 if let HirRelationExpr::Negate { input } = self {
1849 *input
1850 } else {
1851 HirRelationExpr::Negate {
1852 input: Box::new(self),
1853 }
1854 }
1855 }
1856
1857 pub fn distinct(self) -> Self {
1858 if let HirRelationExpr::Distinct { .. } = self {
1859 self
1860 } else {
1861 HirRelationExpr::Distinct {
1862 input: Box::new(self),
1863 }
1864 }
1865 }
1866
1867 pub fn threshold(self) -> Self {
1868 if let HirRelationExpr::Threshold { .. } = self {
1869 self
1870 } else {
1871 HirRelationExpr::Threshold {
1872 input: Box::new(self),
1873 }
1874 }
1875 }
1876
1877 pub fn union(self, other: Self) -> Self {
1878 let mut terms = Vec::new();
1879 if let HirRelationExpr::Union { base, inputs } = self {
1880 terms.push(*base);
1881 terms.extend(inputs);
1882 } else {
1883 terms.push(self);
1884 }
1885 if let HirRelationExpr::Union { base, inputs } = other {
1886 terms.push(*base);
1887 terms.extend(inputs);
1888 } else {
1889 terms.push(other);
1890 }
1891 HirRelationExpr::Union {
1892 base: Box::new(terms.remove(0)),
1893 inputs: terms,
1894 }
1895 }
1896
1897 pub fn exists(self) -> HirScalarExpr {
1898 HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1899 }
1900
1901 pub fn select(self) -> HirScalarExpr {
1902 HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1903 }
1904
1905 pub fn join(
1906 self,
1907 mut right: HirRelationExpr,
1908 on: HirScalarExpr,
1909 kind: JoinKind,
1910 ) -> HirRelationExpr {
1911 if self.is_join_identity()
1912 && !right.is_correlated()
1913 && on == HirScalarExpr::literal_true()
1914 && kind.can_elide_identity_left_join()
1915 {
1916 #[allow(deprecated)]
1920 right.visit_columns_mut(0, &mut |depth, col| {
1921 if col.level > depth {
1922 col.level -= 1;
1923 }
1924 });
1925 right
1926 } else if right.is_join_identity()
1927 && on == HirScalarExpr::literal_true()
1928 && kind.can_elide_identity_right_join()
1929 {
1930 self
1931 } else {
1932 HirRelationExpr::Join {
1933 left: Box::new(self),
1934 right: Box::new(right),
1935 on,
1936 kind,
1937 }
1938 }
1939 }
1940
1941 pub fn take(&mut self) -> HirRelationExpr {
1942 mem::replace(
1943 self,
1944 HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1945 )
1946 }
1947
1948 #[deprecated = "Use `Visit::visit_post`."]
1949 pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1950 where
1951 F: FnMut(&'a Self, usize),
1952 {
1953 #[allow(deprecated)]
1954 let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1955 depth: usize|
1956 -> Result<(), ()> {
1957 f(e, depth);
1958 Ok(())
1959 });
1960 }
1961
1962 #[deprecated = "Use `Visit::try_visit_post`."]
1963 pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1964 where
1965 F: FnMut(&'a Self, usize) -> Result<(), E>,
1966 {
1967 #[allow(deprecated)]
1968 self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1969 e.visit_fallible(depth, f)
1970 })?;
1971 f(self, depth)
1972 }
1973
1974 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1975 pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1976 where
1977 F: FnMut(&'a Self, usize) -> Result<(), E>,
1978 {
1979 match self {
1980 HirRelationExpr::Constant { .. }
1981 | HirRelationExpr::Get { .. }
1982 | HirRelationExpr::CallTable { .. } => (),
1983 HirRelationExpr::Let { body, value, .. } => {
1984 f(value, depth)?;
1985 f(body, depth)?;
1986 }
1987 HirRelationExpr::LetRec {
1988 limit: _,
1989 bindings,
1990 body,
1991 } => {
1992 for (_, _, value, _) in bindings.iter() {
1993 f(value, depth)?;
1994 }
1995 f(body, depth)?;
1996 }
1997 HirRelationExpr::Project { input, .. } => {
1998 f(input, depth)?;
1999 }
2000 HirRelationExpr::Map { input, .. } => {
2001 f(input, depth)?;
2002 }
2003 HirRelationExpr::Filter { input, .. } => {
2004 f(input, depth)?;
2005 }
2006 HirRelationExpr::Join { left, right, .. } => {
2007 f(left, depth)?;
2008 f(right, depth + 1)?;
2009 }
2010 HirRelationExpr::Reduce { input, .. } => {
2011 f(input, depth)?;
2012 }
2013 HirRelationExpr::Distinct { input } => {
2014 f(input, depth)?;
2015 }
2016 HirRelationExpr::TopK { input, .. } => {
2017 f(input, depth)?;
2018 }
2019 HirRelationExpr::Negate { input } => {
2020 f(input, depth)?;
2021 }
2022 HirRelationExpr::Threshold { input } => {
2023 f(input, depth)?;
2024 }
2025 HirRelationExpr::Union { base, inputs } => {
2026 f(base, depth)?;
2027 for input in inputs {
2028 f(input, depth)?;
2029 }
2030 }
2031 }
2032 Ok(())
2033 }
2034
2035 #[deprecated = "Use `Visit::visit_mut_post` instead."]
2036 pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2037 where
2038 F: FnMut(&mut Self, usize),
2039 {
2040 #[allow(deprecated)]
2041 let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2042 depth: usize|
2043 -> Result<(), ()> {
2044 f(e, depth);
2045 Ok(())
2046 });
2047 }
2048
2049 #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2050 pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2051 where
2052 F: FnMut(&mut Self, usize) -> Result<(), E>,
2053 {
2054 #[allow(deprecated)]
2055 self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2056 e.visit_mut_fallible(depth, f)
2057 })?;
2058 f(self, depth)
2059 }
2060
2061 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2062 pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2063 where
2064 F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2065 {
2066 match self {
2067 HirRelationExpr::Constant { .. }
2068 | HirRelationExpr::Get { .. }
2069 | HirRelationExpr::CallTable { .. } => (),
2070 HirRelationExpr::Let { body, value, .. } => {
2071 f(value, depth)?;
2072 f(body, depth)?;
2073 }
2074 HirRelationExpr::LetRec {
2075 limit: _,
2076 bindings,
2077 body,
2078 } => {
2079 for (_, _, value, _) in bindings.iter_mut() {
2080 f(value, depth)?;
2081 }
2082 f(body, depth)?;
2083 }
2084 HirRelationExpr::Project { input, .. } => {
2085 f(input, depth)?;
2086 }
2087 HirRelationExpr::Map { input, .. } => {
2088 f(input, depth)?;
2089 }
2090 HirRelationExpr::Filter { input, .. } => {
2091 f(input, depth)?;
2092 }
2093 HirRelationExpr::Join { left, right, .. } => {
2094 f(left, depth)?;
2095 f(right, depth + 1)?;
2096 }
2097 HirRelationExpr::Reduce { input, .. } => {
2098 f(input, depth)?;
2099 }
2100 HirRelationExpr::Distinct { input } => {
2101 f(input, depth)?;
2102 }
2103 HirRelationExpr::TopK { input, .. } => {
2104 f(input, depth)?;
2105 }
2106 HirRelationExpr::Negate { input } => {
2107 f(input, depth)?;
2108 }
2109 HirRelationExpr::Threshold { input } => {
2110 f(input, depth)?;
2111 }
2112 HirRelationExpr::Union { base, inputs } => {
2113 f(base, depth)?;
2114 for input in inputs {
2115 f(input, depth)?;
2116 }
2117 }
2118 }
2119 Ok(())
2120 }
2121
2122 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2123 pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2129 where
2130 F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2131 {
2132 #[allow(deprecated)]
2133 self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2134 depth: usize|
2135 -> Result<(), E> {
2136 match e {
2137 HirRelationExpr::Join { on, .. } => {
2138 f(on, depth)?;
2139 }
2140 HirRelationExpr::Map { scalars, .. } => {
2141 for scalar in scalars {
2142 f(scalar, depth)?;
2143 }
2144 }
2145 HirRelationExpr::CallTable { exprs, .. } => {
2146 for expr in exprs {
2147 f(expr, depth)?;
2148 }
2149 }
2150 HirRelationExpr::Filter { predicates, .. } => {
2151 for predicate in predicates {
2152 f(predicate, depth)?;
2153 }
2154 }
2155 HirRelationExpr::Reduce { aggregates, .. } => {
2156 for aggregate in aggregates {
2157 f(&aggregate.expr, depth)?;
2158 }
2159 }
2160 HirRelationExpr::TopK { limit, offset, .. } => {
2161 if let Some(limit) = limit {
2162 f(limit, depth)?;
2163 }
2164 f(offset, depth)?;
2165 }
2166 HirRelationExpr::Union { .. }
2167 | HirRelationExpr::Let { .. }
2168 | HirRelationExpr::LetRec { .. }
2169 | HirRelationExpr::Project { .. }
2170 | HirRelationExpr::Distinct { .. }
2171 | HirRelationExpr::Negate { .. }
2172 | HirRelationExpr::Threshold { .. }
2173 | HirRelationExpr::Constant { .. }
2174 | HirRelationExpr::Get { .. } => (),
2175 }
2176 Ok(())
2177 })
2178 }
2179
2180 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2181 pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2183 where
2184 F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2185 {
2186 #[allow(deprecated)]
2187 self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2188 depth: usize|
2189 -> Result<(), E> {
2190 match e {
2191 HirRelationExpr::Join { on, .. } => {
2192 f(on, depth)?;
2193 }
2194 HirRelationExpr::Map { scalars, .. } => {
2195 for scalar in scalars.iter_mut() {
2196 f(scalar, depth)?;
2197 }
2198 }
2199 HirRelationExpr::CallTable { exprs, .. } => {
2200 for expr in exprs.iter_mut() {
2201 f(expr, depth)?;
2202 }
2203 }
2204 HirRelationExpr::Filter { predicates, .. } => {
2205 for predicate in predicates.iter_mut() {
2206 f(predicate, depth)?;
2207 }
2208 }
2209 HirRelationExpr::Reduce { aggregates, .. } => {
2210 for aggregate in aggregates.iter_mut() {
2211 f(&mut aggregate.expr, depth)?;
2212 }
2213 }
2214 HirRelationExpr::TopK { limit, offset, .. } => {
2215 if let Some(limit) = limit {
2216 f(limit, depth)?;
2217 }
2218 f(offset, depth)?;
2219 }
2220 HirRelationExpr::Union { .. }
2221 | HirRelationExpr::Let { .. }
2222 | HirRelationExpr::LetRec { .. }
2223 | HirRelationExpr::Project { .. }
2224 | HirRelationExpr::Distinct { .. }
2225 | HirRelationExpr::Negate { .. }
2226 | HirRelationExpr::Threshold { .. }
2227 | HirRelationExpr::Constant { .. }
2228 | HirRelationExpr::Get { .. } => (),
2229 }
2230 Ok(())
2231 })
2232 }
2233
2234 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2235 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2241 where
2242 F: FnMut(usize, &ColumnRef),
2243 {
2244 #[allow(deprecated)]
2245 let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2246 depth: usize|
2247 -> Result<(), ()> {
2248 e.visit_columns(depth, f);
2249 Ok(())
2250 });
2251 }
2252
2253 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2254 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2256 where
2257 F: FnMut(usize, &mut ColumnRef),
2258 {
2259 #[allow(deprecated)]
2260 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2261 depth: usize|
2262 -> Result<(), ()> {
2263 e.visit_columns_mut(depth, f);
2264 Ok(())
2265 });
2266 }
2267
2268 pub fn bind_parameters(
2271 &mut self,
2272 scx: &StatementContext,
2273 lifetime: QueryLifetime,
2274 params: &Params,
2275 ) -> Result<(), PlanError> {
2276 #[allow(deprecated)]
2277 self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2278 e.bind_parameters(scx, lifetime, params)
2279 })
2280 }
2281
2282 pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2283 let mut contains_parameters = false;
2284 #[allow(deprecated)]
2285 self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2286 if e.contains_parameters() {
2287 contains_parameters = true;
2288 }
2289 Ok::<(), PlanError>(())
2290 })?;
2291 Ok(contains_parameters)
2292 }
2293
2294 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2296 #[allow(deprecated)]
2297 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2298 depth: usize|
2299 -> Result<(), ()> {
2300 e.splice_parameters(params, depth);
2301 Ok(())
2302 });
2303 }
2304
2305 pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2307 let rows = rows
2308 .into_iter()
2309 .map(move |datums| Row::pack_slice(&datums))
2310 .collect();
2311 HirRelationExpr::Constant { rows, typ }
2312 }
2313
2314 pub fn finish_maintained(
2320 &mut self,
2321 finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2322 group_size_hints: GroupSizeHints,
2323 ) {
2324 if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2325 let old_finishing = mem::replace(
2326 finishing,
2327 HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2328 );
2329 *self = HirRelationExpr::top_k(
2330 std::mem::replace(
2331 self,
2332 HirRelationExpr::Constant {
2333 rows: vec![],
2334 typ: SqlRelationType::new(Vec::new()),
2335 },
2336 ),
2337 vec![],
2338 old_finishing.order_by,
2339 old_finishing.limit,
2340 old_finishing.offset,
2341 group_size_hints.limit_input_group_size,
2342 )
2343 .project(old_finishing.project);
2344 }
2345 }
2346
2347 pub fn trivial_row_set_finishing_hir(
2352 arity: usize,
2353 ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2354 RowSetFinishing {
2355 order_by: Vec::new(),
2356 limit: None,
2357 offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2358 project: (0..arity).collect(),
2359 }
2360 }
2361
2362 pub fn is_trivial_row_set_finishing_hir(
2367 rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2368 arity: usize,
2369 ) -> bool {
2370 rsf.limit.is_none()
2371 && rsf.order_by.is_empty()
2372 && rsf
2373 .offset
2374 .clone()
2375 .try_into_literal_int64()
2376 .is_ok_and(|o| o == 0)
2377 && rsf.project.iter().copied().eq(0..arity)
2378 }
2379
2380 pub fn could_run_expensive_function(&self) -> bool {
2390 let mut result = false;
2391 if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2392 use HirRelationExpr::*;
2393 use HirScalarExpr::*;
2394
2395 e.visit_children(|scalar: &HirScalarExpr| {
2396 if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2397 result |= match scalar {
2398 Column(..)
2399 | Literal(..)
2400 | CallUnmaterializable(..)
2401 | If { .. }
2402 | Parameter(..)
2403 | Select(..)
2404 | Exists(..) => false,
2405 CallUnary { .. }
2407 | CallBinary { .. }
2408 | CallVariadic { .. }
2409 | Windowing(..) => true,
2410 };
2411 }) {
2412 result = true;
2414 }
2415 });
2416
2417 result |= matches!(e, CallTable { .. } | Reduce { .. });
2420 }) {
2421 result = true;
2423 }
2424
2425 result
2426 }
2427
2428 pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2430 let mut contains = false;
2431 self.visit_post(&mut |expr| {
2432 expr.visit_children(|expr: &HirScalarExpr| {
2433 contains = contains || expr.contains_temporal()
2434 })
2435 })?;
2436 Ok(contains)
2437 }
2438
2439 pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2441 let mut contains = false;
2442 self.visit_post(&mut |expr| {
2443 expr.visit_children(|expr: &HirScalarExpr| {
2444 contains = contains || expr.contains_unmaterializable()
2445 })
2446 })?;
2447 Ok(contains)
2448 }
2449
2450 pub fn contains_unmaterializable_except_temporal(&self) -> Result<bool, RecursionLimitError> {
2453 let mut contains = false;
2454 self.visit_post(&mut |expr| {
2455 expr.visit_children(|expr: &HirScalarExpr| {
2456 contains = contains || expr.contains_unmaterializable_except_temporal()
2457 })
2458 })?;
2459 Ok(contains)
2460 }
2461}
2462
2463impl CollectionPlan for HirRelationExpr {
2464 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2471 if let Self::Get {
2472 id: Id::Global(id), ..
2473 } = self
2474 {
2475 out.insert(*id);
2476 }
2477 self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2478 }
2479}
2480
2481impl VisitChildren<Self> for HirRelationExpr {
2482 fn visit_children<F>(&self, mut f: F)
2483 where
2484 F: FnMut(&Self),
2485 {
2486 VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2490 #[allow(deprecated)]
2491 Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2492 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2493 f(expr.as_ref())
2494 }
2495 _ => (),
2496 });
2497 });
2498
2499 use HirRelationExpr::*;
2500 match self {
2501 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2502 Let {
2503 name: _,
2504 id: _,
2505 value,
2506 body,
2507 } => {
2508 f(value);
2509 f(body);
2510 }
2511 LetRec {
2512 limit: _,
2513 bindings,
2514 body,
2515 } => {
2516 for (_, _, value, _) in bindings.iter() {
2517 f(value);
2518 }
2519 f(body);
2520 }
2521 Project { input, outputs: _ } => f(input),
2522 Map { input, scalars: _ } => {
2523 f(input);
2524 }
2525 CallTable { func: _, exprs: _ } => (),
2526 Filter {
2527 input,
2528 predicates: _,
2529 } => {
2530 f(input);
2531 }
2532 Join {
2533 left,
2534 right,
2535 on: _,
2536 kind: _,
2537 } => {
2538 f(left);
2539 f(right);
2540 }
2541 Reduce {
2542 input,
2543 group_key: _,
2544 aggregates: _,
2545 expected_group_size: _,
2546 } => {
2547 f(input);
2548 }
2549 Distinct { input }
2550 | TopK {
2551 input,
2552 group_key: _,
2553 order_key: _,
2554 limit: _,
2555 offset: _,
2556 expected_group_size: _,
2557 }
2558 | Negate { input }
2559 | Threshold { input } => {
2560 f(input);
2561 }
2562 Union { base, inputs } => {
2563 f(base);
2564 for input in inputs {
2565 f(input);
2566 }
2567 }
2568 }
2569 }
2570
2571 fn visit_mut_children<F>(&mut self, mut f: F)
2572 where
2573 F: FnMut(&mut Self),
2574 {
2575 VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2579 #[allow(deprecated)]
2580 Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2581 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2582 f(expr.as_mut())
2583 }
2584 _ => (),
2585 });
2586 });
2587
2588 use HirRelationExpr::*;
2589 match self {
2590 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2591 Let {
2592 name: _,
2593 id: _,
2594 value,
2595 body,
2596 } => {
2597 f(value);
2598 f(body);
2599 }
2600 LetRec {
2601 limit: _,
2602 bindings,
2603 body,
2604 } => {
2605 for (_, _, value, _) in bindings.iter_mut() {
2606 f(value);
2607 }
2608 f(body);
2609 }
2610 Project { input, outputs: _ } => f(input),
2611 Map { input, scalars: _ } => {
2612 f(input);
2613 }
2614 CallTable { func: _, exprs: _ } => (),
2615 Filter {
2616 input,
2617 predicates: _,
2618 } => {
2619 f(input);
2620 }
2621 Join {
2622 left,
2623 right,
2624 on: _,
2625 kind: _,
2626 } => {
2627 f(left);
2628 f(right);
2629 }
2630 Reduce {
2631 input,
2632 group_key: _,
2633 aggregates: _,
2634 expected_group_size: _,
2635 } => {
2636 f(input);
2637 }
2638 Distinct { input }
2639 | TopK {
2640 input,
2641 group_key: _,
2642 order_key: _,
2643 limit: _,
2644 offset: _,
2645 expected_group_size: _,
2646 }
2647 | Negate { input }
2648 | Threshold { input } => {
2649 f(input);
2650 }
2651 Union { base, inputs } => {
2652 f(base);
2653 for input in inputs {
2654 f(input);
2655 }
2656 }
2657 }
2658 }
2659
2660 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2661 where
2662 F: FnMut(&Self) -> Result<(), E>,
2663 E: From<RecursionLimitError>,
2664 {
2665 VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2669 Visit::try_visit_post(expr, &mut |expr| match expr {
2670 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2671 f(expr.as_ref())
2672 }
2673 _ => Ok(()),
2674 })
2675 })?;
2676
2677 use HirRelationExpr::*;
2678 match self {
2679 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2680 Let {
2681 name: _,
2682 id: _,
2683 value,
2684 body,
2685 } => {
2686 f(value)?;
2687 f(body)?;
2688 }
2689 LetRec {
2690 limit: _,
2691 bindings,
2692 body,
2693 } => {
2694 for (_, _, value, _) in bindings.iter() {
2695 f(value)?;
2696 }
2697 f(body)?;
2698 }
2699 Project { input, outputs: _ } => f(input)?,
2700 Map { input, scalars: _ } => {
2701 f(input)?;
2702 }
2703 CallTable { func: _, exprs: _ } => (),
2704 Filter {
2705 input,
2706 predicates: _,
2707 } => {
2708 f(input)?;
2709 }
2710 Join {
2711 left,
2712 right,
2713 on: _,
2714 kind: _,
2715 } => {
2716 f(left)?;
2717 f(right)?;
2718 }
2719 Reduce {
2720 input,
2721 group_key: _,
2722 aggregates: _,
2723 expected_group_size: _,
2724 } => {
2725 f(input)?;
2726 }
2727 Distinct { input }
2728 | TopK {
2729 input,
2730 group_key: _,
2731 order_key: _,
2732 limit: _,
2733 offset: _,
2734 expected_group_size: _,
2735 }
2736 | Negate { input }
2737 | Threshold { input } => {
2738 f(input)?;
2739 }
2740 Union { base, inputs } => {
2741 f(base)?;
2742 for input in inputs {
2743 f(input)?;
2744 }
2745 }
2746 }
2747 Ok(())
2748 }
2749
2750 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2751 where
2752 F: FnMut(&mut Self) -> Result<(), E>,
2753 E: From<RecursionLimitError>,
2754 {
2755 VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2759 Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2760 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2761 f(expr.as_mut())
2762 }
2763 _ => Ok(()),
2764 })
2765 })?;
2766
2767 use HirRelationExpr::*;
2768 match self {
2769 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2770 Let {
2771 name: _,
2772 id: _,
2773 value,
2774 body,
2775 } => {
2776 f(value)?;
2777 f(body)?;
2778 }
2779 LetRec {
2780 limit: _,
2781 bindings,
2782 body,
2783 } => {
2784 for (_, _, value, _) in bindings.iter_mut() {
2785 f(value)?;
2786 }
2787 f(body)?;
2788 }
2789 Project { input, outputs: _ } => f(input)?,
2790 Map { input, scalars: _ } => {
2791 f(input)?;
2792 }
2793 CallTable { func: _, exprs: _ } => (),
2794 Filter {
2795 input,
2796 predicates: _,
2797 } => {
2798 f(input)?;
2799 }
2800 Join {
2801 left,
2802 right,
2803 on: _,
2804 kind: _,
2805 } => {
2806 f(left)?;
2807 f(right)?;
2808 }
2809 Reduce {
2810 input,
2811 group_key: _,
2812 aggregates: _,
2813 expected_group_size: _,
2814 } => {
2815 f(input)?;
2816 }
2817 Distinct { input }
2818 | TopK {
2819 input,
2820 group_key: _,
2821 order_key: _,
2822 limit: _,
2823 offset: _,
2824 expected_group_size: _,
2825 }
2826 | Negate { input }
2827 | Threshold { input } => {
2828 f(input)?;
2829 }
2830 Union { base, inputs } => {
2831 f(base)?;
2832 for input in inputs {
2833 f(input)?;
2834 }
2835 }
2836 }
2837 Ok(())
2838 }
2839}
2840
2841impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2842 fn visit_children<F>(&self, mut f: F)
2843 where
2844 F: FnMut(&HirScalarExpr),
2845 {
2846 use HirRelationExpr::*;
2847 match self {
2848 Constant { rows: _, typ: _ }
2849 | Get { id: _, typ: _ }
2850 | Let {
2851 name: _,
2852 id: _,
2853 value: _,
2854 body: _,
2855 }
2856 | LetRec {
2857 limit: _,
2858 bindings: _,
2859 body: _,
2860 }
2861 | Project {
2862 input: _,
2863 outputs: _,
2864 } => (),
2865 Map { input: _, scalars } => {
2866 for scalar in scalars {
2867 f(scalar);
2868 }
2869 }
2870 CallTable { func: _, exprs } => {
2871 for expr in exprs {
2872 f(expr);
2873 }
2874 }
2875 Filter {
2876 input: _,
2877 predicates,
2878 } => {
2879 for predicate in predicates {
2880 f(predicate);
2881 }
2882 }
2883 Join {
2884 left: _,
2885 right: _,
2886 on,
2887 kind: _,
2888 } => f(on),
2889 Reduce {
2890 input: _,
2891 group_key: _,
2892 aggregates,
2893 expected_group_size: _,
2894 } => {
2895 for aggregate in aggregates {
2896 f(aggregate.expr.as_ref());
2897 }
2898 }
2899 TopK {
2900 input: _,
2901 group_key: _,
2902 order_key: _,
2903 limit,
2904 offset,
2905 expected_group_size: _,
2906 } => {
2907 if let Some(limit) = limit {
2908 f(limit)
2909 }
2910 f(offset)
2911 }
2912 Distinct { input: _ }
2913 | Negate { input: _ }
2914 | Threshold { input: _ }
2915 | Union { base: _, inputs: _ } => (),
2916 }
2917 }
2918
2919 fn visit_mut_children<F>(&mut self, mut f: F)
2920 where
2921 F: FnMut(&mut HirScalarExpr),
2922 {
2923 use HirRelationExpr::*;
2924 match self {
2925 Constant { rows: _, typ: _ }
2926 | Get { id: _, typ: _ }
2927 | Let {
2928 name: _,
2929 id: _,
2930 value: _,
2931 body: _,
2932 }
2933 | LetRec {
2934 limit: _,
2935 bindings: _,
2936 body: _,
2937 }
2938 | Project {
2939 input: _,
2940 outputs: _,
2941 } => (),
2942 Map { input: _, scalars } => {
2943 for scalar in scalars {
2944 f(scalar);
2945 }
2946 }
2947 CallTable { func: _, exprs } => {
2948 for expr in exprs {
2949 f(expr);
2950 }
2951 }
2952 Filter {
2953 input: _,
2954 predicates,
2955 } => {
2956 for predicate in predicates {
2957 f(predicate);
2958 }
2959 }
2960 Join {
2961 left: _,
2962 right: _,
2963 on,
2964 kind: _,
2965 } => f(on),
2966 Reduce {
2967 input: _,
2968 group_key: _,
2969 aggregates,
2970 expected_group_size: _,
2971 } => {
2972 for aggregate in aggregates {
2973 f(aggregate.expr.as_mut());
2974 }
2975 }
2976 TopK {
2977 input: _,
2978 group_key: _,
2979 order_key: _,
2980 limit,
2981 offset,
2982 expected_group_size: _,
2983 } => {
2984 if let Some(limit) = limit {
2985 f(limit)
2986 }
2987 f(offset)
2988 }
2989 Distinct { input: _ }
2990 | Negate { input: _ }
2991 | Threshold { input: _ }
2992 | Union { base: _, inputs: _ } => (),
2993 }
2994 }
2995
2996 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2997 where
2998 F: FnMut(&HirScalarExpr) -> Result<(), E>,
2999 E: From<RecursionLimitError>,
3000 {
3001 use HirRelationExpr::*;
3002 match self {
3003 Constant { rows: _, typ: _ }
3004 | Get { id: _, typ: _ }
3005 | Let {
3006 name: _,
3007 id: _,
3008 value: _,
3009 body: _,
3010 }
3011 | LetRec {
3012 limit: _,
3013 bindings: _,
3014 body: _,
3015 }
3016 | Project {
3017 input: _,
3018 outputs: _,
3019 } => (),
3020 Map { input: _, scalars } => {
3021 for scalar in scalars {
3022 f(scalar)?;
3023 }
3024 }
3025 CallTable { func: _, exprs } => {
3026 for expr in exprs {
3027 f(expr)?;
3028 }
3029 }
3030 Filter {
3031 input: _,
3032 predicates,
3033 } => {
3034 for predicate in predicates {
3035 f(predicate)?;
3036 }
3037 }
3038 Join {
3039 left: _,
3040 right: _,
3041 on,
3042 kind: _,
3043 } => f(on)?,
3044 Reduce {
3045 input: _,
3046 group_key: _,
3047 aggregates,
3048 expected_group_size: _,
3049 } => {
3050 for aggregate in aggregates {
3051 f(aggregate.expr.as_ref())?;
3052 }
3053 }
3054 TopK {
3055 input: _,
3056 group_key: _,
3057 order_key: _,
3058 limit,
3059 offset,
3060 expected_group_size: _,
3061 } => {
3062 if let Some(limit) = limit {
3063 f(limit)?
3064 }
3065 f(offset)?
3066 }
3067 Distinct { input: _ }
3068 | Negate { input: _ }
3069 | Threshold { input: _ }
3070 | Union { base: _, inputs: _ } => (),
3071 }
3072 Ok(())
3073 }
3074
3075 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3076 where
3077 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3078 E: From<RecursionLimitError>,
3079 {
3080 use HirRelationExpr::*;
3081 match self {
3082 Constant { rows: _, typ: _ }
3083 | Get { id: _, typ: _ }
3084 | Let {
3085 name: _,
3086 id: _,
3087 value: _,
3088 body: _,
3089 }
3090 | LetRec {
3091 limit: _,
3092 bindings: _,
3093 body: _,
3094 }
3095 | Project {
3096 input: _,
3097 outputs: _,
3098 } => (),
3099 Map { input: _, scalars } => {
3100 for scalar in scalars {
3101 f(scalar)?;
3102 }
3103 }
3104 CallTable { func: _, exprs } => {
3105 for expr in exprs {
3106 f(expr)?;
3107 }
3108 }
3109 Filter {
3110 input: _,
3111 predicates,
3112 } => {
3113 for predicate in predicates {
3114 f(predicate)?;
3115 }
3116 }
3117 Join {
3118 left: _,
3119 right: _,
3120 on,
3121 kind: _,
3122 } => f(on)?,
3123 Reduce {
3124 input: _,
3125 group_key: _,
3126 aggregates,
3127 expected_group_size: _,
3128 } => {
3129 for aggregate in aggregates {
3130 f(aggregate.expr.as_mut())?;
3131 }
3132 }
3133 TopK {
3134 input: _,
3135 group_key: _,
3136 order_key: _,
3137 limit,
3138 offset,
3139 expected_group_size: _,
3140 } => {
3141 if let Some(limit) = limit {
3142 f(limit)?
3143 }
3144 f(offset)?
3145 }
3146 Distinct { input: _ }
3147 | Negate { input: _ }
3148 | Threshold { input: _ }
3149 | Union { base: _, inputs: _ } => (),
3150 }
3151 Ok(())
3152 }
3153}
3154
3155impl HirScalarExpr {
3156 pub fn name(&self) -> Option<Arc<str>> {
3157 use HirScalarExpr::*;
3158 match self {
3159 Column(_, name)
3160 | Parameter(_, name)
3161 | Literal(_, _, name)
3162 | CallUnmaterializable(_, name)
3163 | CallUnary { name, .. }
3164 | CallBinary { name, .. }
3165 | CallVariadic { name, .. }
3166 | If { name, .. }
3167 | Exists(_, name)
3168 | Select(_, name)
3169 | Windowing(_, name) => name.0.clone(),
3170 }
3171 }
3172
3173 pub fn bind_parameters(
3176 &mut self,
3177 scx: &StatementContext,
3178 lifetime: QueryLifetime,
3179 params: &Params,
3180 ) -> Result<(), PlanError> {
3181 #[allow(deprecated)]
3182 self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3183 if let HirScalarExpr::Parameter(n, name) = e {
3184 let datum = match params.datums.iter().nth(*n - 1) {
3185 None => return Err(PlanError::UnknownParameter(*n)),
3186 Some(datum) => datum,
3187 };
3188 let scalar_type = ¶ms.execute_types[*n - 1];
3189 let row = Row::pack([datum]);
3190 let column_type = scalar_type.clone().nullable(datum.is_null());
3191
3192 let name = if let Some(name) = &name.0 {
3193 Some(Arc::clone(name))
3194 } else {
3195 Some(Arc::from(format!("${n}")))
3196 };
3197
3198 let qcx = QueryContext::root(scx, lifetime);
3199 let ecx = execute_expr_context(&qcx);
3200
3201 *e = plan_cast(
3202 &ecx,
3203 *EXECUTE_CAST_CONTEXT,
3204 HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3205 ¶ms.expected_types[*n - 1],
3206 )
3207 .expect("checked in plan_params");
3208 }
3209 Ok(())
3210 })
3211 }
3212
3213 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3224 #[allow(deprecated)]
3225 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3226 e: &mut HirScalarExpr|
3227 -> Result<(), ()> {
3228 if let HirScalarExpr::Parameter(i, _name) = e {
3229 *e = params[*i - 1].clone();
3230 e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3233 if col.level >= d {
3234 col.level += depth
3235 }
3236 });
3237 }
3238 Ok(())
3239 });
3240 }
3241
3242 pub fn contains_temporal(&self) -> bool {
3244 let mut contains = false;
3245 #[allow(deprecated)]
3246 self.visit_post_nolimit(&mut |e| {
3247 if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3248 contains = true;
3249 }
3250 });
3251 contains
3252 }
3253
3254 pub fn contains_unmaterializable(&self) -> bool {
3256 let mut contains = false;
3257 #[allow(deprecated)]
3258 self.visit_post_nolimit(&mut |e| {
3259 if let Self::CallUnmaterializable(_, _) = e {
3260 contains = true;
3261 }
3262 });
3263 contains
3264 }
3265
3266 pub fn contains_unmaterializable_except_temporal(&self) -> bool {
3269 let mut contains = false;
3270 #[allow(deprecated)]
3271 self.visit_post_nolimit(&mut |e| {
3272 if let Self::CallUnmaterializable(f, _) = e {
3273 if *f != UnmaterializableFunc::MzNow {
3274 contains = true;
3275 }
3276 }
3277 });
3278 contains
3279 }
3280
3281 pub fn column(index: usize) -> HirScalarExpr {
3285 HirScalarExpr::Column(
3286 ColumnRef {
3287 level: 0,
3288 column: index,
3289 },
3290 TreatAsEqual(None),
3291 )
3292 }
3293
3294 pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3296 HirScalarExpr::Column(cr, TreatAsEqual(None))
3297 }
3298
3299 pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3302 HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3303 }
3304
3305 pub fn parameter(n: usize) -> HirScalarExpr {
3306 HirScalarExpr::Parameter(n, TreatAsEqual(None))
3307 }
3308
3309 pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3310 let col_type = scalar_type.nullable(datum.is_null());
3311 soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3312 let row = Row::pack([datum]);
3313 HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3314 }
3315
3316 pub fn literal_true() -> HirScalarExpr {
3317 HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3318 }
3319
3320 pub fn literal_false() -> HirScalarExpr {
3321 HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3322 }
3323
3324 pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3325 HirScalarExpr::literal(Datum::Null, scalar_type)
3326 }
3327
3328 pub fn literal_1d_array(
3329 datums: Vec<Datum>,
3330 element_scalar_type: SqlScalarType,
3331 ) -> Result<HirScalarExpr, PlanError> {
3332 let scalar_type = match element_scalar_type {
3333 SqlScalarType::Array(_) => {
3334 sql_bail!("cannot build array from array type");
3335 }
3336 typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3337 };
3338
3339 let mut row = Row::default();
3340 row.packer()
3341 .try_push_array(
3342 &[ArrayDimension {
3343 lower_bound: 1,
3344 length: datums.len(),
3345 }],
3346 datums,
3347 )
3348 .expect("array constructed to be valid");
3349
3350 Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3351 }
3352
3353 pub fn as_literal(&self) -> Option<Datum<'_>> {
3354 if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3355 Some(row.unpack_first())
3356 } else {
3357 None
3358 }
3359 }
3360
3361 pub fn is_literal_true(&self) -> bool {
3362 Some(Datum::True) == self.as_literal()
3363 }
3364
3365 pub fn is_literal_false(&self) -> bool {
3366 Some(Datum::False) == self.as_literal()
3367 }
3368
3369 pub fn is_literal_null(&self) -> bool {
3370 Some(Datum::Null) == self.as_literal()
3371 }
3372
3373 pub fn is_constant(&self) -> bool {
3376 let mut worklist = vec![self];
3377 while let Some(expr) = worklist.pop() {
3378 match expr {
3379 Self::Literal(..) => {
3380 }
3382 Self::CallUnary { expr, .. } => {
3383 worklist.push(expr);
3384 }
3385 Self::CallBinary {
3386 func: _,
3387 expr1,
3388 expr2,
3389 name: _,
3390 } => {
3391 worklist.push(expr1);
3392 worklist.push(expr2);
3393 }
3394 Self::CallVariadic {
3395 func: _,
3396 exprs,
3397 name: _,
3398 } => {
3399 worklist.extend(exprs.iter());
3400 }
3401 Self::If {
3403 cond,
3404 then,
3405 els,
3406 name: _,
3407 } => {
3408 worklist.push(cond);
3409 worklist.push(then);
3410 worklist.push(els);
3411 }
3412 _ => {
3413 return false; }
3415 }
3416 }
3417 true
3418 }
3419
3420 pub fn call_unary(self, func: UnaryFunc) -> Self {
3421 HirScalarExpr::CallUnary {
3422 func,
3423 expr: Box::new(self),
3424 name: NameMetadata::default(),
3425 }
3426 }
3427
3428 pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3429 HirScalarExpr::CallBinary {
3430 func: func.into(),
3431 expr1: Box::new(self),
3432 expr2: Box::new(other),
3433 name: NameMetadata::default(),
3434 }
3435 }
3436
3437 pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3438 HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3439 }
3440
3441 pub fn call_variadic<V: Into<VariadicFunc>>(func: V, exprs: Vec<Self>) -> Self {
3442 HirScalarExpr::CallVariadic {
3443 func: func.into(),
3444 exprs,
3445 name: NameMetadata::default(),
3446 }
3447 }
3448
3449 pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3450 HirScalarExpr::If {
3451 cond: Box::new(cond),
3452 then: Box::new(then),
3453 els: Box::new(els),
3454 name: NameMetadata::default(),
3455 }
3456 }
3457
3458 pub fn windowing(expr: WindowExpr) -> Self {
3459 HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3460 }
3461
3462 pub fn or(self, other: Self) -> Self {
3463 HirScalarExpr::call_variadic(Or, vec![self, other])
3464 }
3465
3466 pub fn and(self, other: Self) -> Self {
3467 HirScalarExpr::call_variadic(And, vec![self, other])
3468 }
3469
3470 pub fn not(self) -> Self {
3471 self.call_unary(UnaryFunc::Not(func::Not))
3472 }
3473
3474 pub fn call_is_null(self) -> Self {
3475 self.call_unary(UnaryFunc::IsNull(func::IsNull))
3476 }
3477
3478 pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3480 match args.len() {
3481 0 => HirScalarExpr::literal_true(), 1 => args.swap_remove(0),
3483 _ => HirScalarExpr::call_variadic(And, args),
3484 }
3485 }
3486
3487 pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3489 match args.len() {
3490 0 => HirScalarExpr::literal_false(), 1 => args.swap_remove(0),
3492 _ => HirScalarExpr::call_variadic(Or, args),
3493 }
3494 }
3495
3496 pub fn take(&mut self) -> Self {
3497 mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3498 }
3499
3500 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3501 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3507 where
3508 F: FnMut(usize, &ColumnRef),
3509 {
3510 #[allow(deprecated)]
3511 let _ = self.visit_recursively(depth, &mut |depth: usize,
3512 e: &HirScalarExpr|
3513 -> Result<(), ()> {
3514 if let HirScalarExpr::Column(col, _name) = e {
3515 f(depth, col)
3516 }
3517 Ok(())
3518 });
3519 }
3520
3521 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3522 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3524 where
3525 F: FnMut(usize, &mut ColumnRef),
3526 {
3527 #[allow(deprecated)]
3528 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3529 e: &mut HirScalarExpr|
3530 -> Result<(), ()> {
3531 if let HirScalarExpr::Column(col, _name) = e {
3532 f(depth, col)
3533 }
3534 Ok(())
3535 });
3536 }
3537
3538 pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3544 where
3545 F: FnMut(usize),
3546 {
3547 #[allow(deprecated)]
3548 let _ = self.visit_recursively(0, &mut |depth: usize,
3549 e: &HirScalarExpr|
3550 -> Result<(), ()> {
3551 if let HirScalarExpr::Column(col, _name) = e {
3552 if col.level == depth {
3553 f(col.column)
3554 }
3555 }
3556 Ok(())
3557 });
3558 }
3559
3560 pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3562 where
3563 F: FnMut(&mut usize),
3564 {
3565 #[allow(deprecated)]
3566 let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3567 e: &mut HirScalarExpr|
3568 -> Result<(), ()> {
3569 if let HirScalarExpr::Column(col, _name) = e {
3570 if col.level == depth {
3571 f(&mut col.column)
3572 }
3573 }
3574 Ok(())
3575 });
3576 }
3577
3578 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3579 pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3583 where
3584 F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3585 {
3586 match self {
3587 HirScalarExpr::Literal(..)
3588 | HirScalarExpr::Parameter(..)
3589 | HirScalarExpr::CallUnmaterializable(..)
3590 | HirScalarExpr::Column(..) => (),
3591 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3592 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3593 expr1.visit_recursively(depth, f)?;
3594 expr2.visit_recursively(depth, f)?;
3595 }
3596 HirScalarExpr::CallVariadic { exprs, .. } => {
3597 for expr in exprs {
3598 expr.visit_recursively(depth, f)?;
3599 }
3600 }
3601 HirScalarExpr::If {
3602 cond,
3603 then,
3604 els,
3605 name: _,
3606 } => {
3607 cond.visit_recursively(depth, f)?;
3608 then.visit_recursively(depth, f)?;
3609 els.visit_recursively(depth, f)?;
3610 }
3611 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3612 #[allow(deprecated)]
3613 expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3614 e.visit_recursively(depth, f)
3615 })?;
3616 }
3617 HirScalarExpr::Windowing(expr, _name) => {
3618 expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3619 }
3620 }
3621 f(depth, self)
3622 }
3623
3624 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3625 pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3627 where
3628 F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3629 {
3630 match self {
3631 HirScalarExpr::Literal(..)
3632 | HirScalarExpr::Parameter(..)
3633 | HirScalarExpr::CallUnmaterializable(..)
3634 | HirScalarExpr::Column(..) => (),
3635 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3636 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3637 expr1.visit_recursively_mut(depth, f)?;
3638 expr2.visit_recursively_mut(depth, f)?;
3639 }
3640 HirScalarExpr::CallVariadic { exprs, .. } => {
3641 for expr in exprs {
3642 expr.visit_recursively_mut(depth, f)?;
3643 }
3644 }
3645 HirScalarExpr::If {
3646 cond,
3647 then,
3648 els,
3649 name: _,
3650 } => {
3651 cond.visit_recursively_mut(depth, f)?;
3652 then.visit_recursively_mut(depth, f)?;
3653 els.visit_recursively_mut(depth, f)?;
3654 }
3655 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3656 #[allow(deprecated)]
3657 expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3658 e.visit_recursively_mut(depth, f)
3659 })?;
3660 }
3661 HirScalarExpr::Windowing(expr, _name) => {
3662 expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3663 }
3664 }
3665 f(depth, self)
3666 }
3667
3668 fn simplify_to_literal(self) -> Option<Row> {
3677 let mut expr = self
3678 .lower_uncorrelated(crate::plan::lowering::Config::default())
3679 .ok()?;
3680 expr.reduce(&[]);
3684 match expr {
3685 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3686 _ => None,
3687 }
3688 }
3689
3690 fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3703 let mut expr = self
3704 .lower_uncorrelated(crate::plan::lowering::Config::default())
3705 .map_err(|err| {
3706 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3707 })?;
3708 expr.reduce(&[]);
3712 match expr {
3713 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3714 mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3715 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3716 ),
3717 _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3718 "Not a constant".to_string(),
3719 )),
3720 }
3721 }
3722
3723 pub fn into_literal_int64(self) -> Option<i64> {
3732 self.simplify_to_literal().and_then(|row| {
3733 let datum = row.unpack_first();
3734 if datum.is_null() {
3735 None
3736 } else {
3737 Some(datum.unwrap_int64())
3738 }
3739 })
3740 }
3741
3742 pub fn into_literal_string(self) -> Option<String> {
3751 self.simplify_to_literal().and_then(|row| {
3752 let datum = row.unpack_first();
3753 if datum.is_null() {
3754 None
3755 } else {
3756 Some(datum.unwrap_str().to_owned())
3757 }
3758 })
3759 }
3760
3761 pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3774 self.simplify_to_literal().and_then(|row| {
3775 let datum = row.unpack_first();
3776 if datum.is_null() {
3777 None
3778 } else {
3779 Some(datum.unwrap_mz_timestamp())
3780 }
3781 })
3782 }
3783
3784 pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3796 if !self.is_constant() {
3802 return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3803 "Expected a constant expression, got {}",
3804 self
3805 )));
3806 }
3807 self.clone()
3808 .simplify_to_literal_with_result()
3809 .and_then(|row| {
3810 let datum = row.unpack_first();
3811 if datum.is_null() {
3812 Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3813 "Expected an expression that evaluates to a non-null value, got {}",
3814 self
3815 )))
3816 } else {
3817 Ok(datum.unwrap_int64())
3818 }
3819 })
3820 }
3821
3822 pub fn contains_parameters(&self) -> bool {
3823 let mut contains_parameters = false;
3824 #[allow(deprecated)]
3825 let _ = self.visit_recursively(0, &mut |_depth: usize,
3826 expr: &HirScalarExpr|
3827 -> Result<(), ()> {
3828 if let HirScalarExpr::Parameter(..) = expr {
3829 contains_parameters = true;
3830 }
3831 Ok(())
3832 });
3833 contains_parameters
3834 }
3835}
3836
3837impl VisitChildren<Self> for HirScalarExpr {
3838 fn visit_children<F>(&self, mut f: F)
3839 where
3840 F: FnMut(&Self),
3841 {
3842 use HirScalarExpr::*;
3843 match self {
3844 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3845 CallUnary { expr, .. } => f(expr),
3846 CallBinary { expr1, expr2, .. } => {
3847 f(expr1);
3848 f(expr2);
3849 }
3850 CallVariadic { exprs, .. } => {
3851 for expr in exprs {
3852 f(expr);
3853 }
3854 }
3855 If {
3856 cond,
3857 then,
3858 els,
3859 name: _,
3860 } => {
3861 f(cond);
3862 f(then);
3863 f(els);
3864 }
3865 Exists(..) | Select(..) => (),
3866 Windowing(expr, _name) => expr.visit_children(f),
3867 }
3868 }
3869
3870 fn visit_mut_children<F>(&mut self, mut f: F)
3871 where
3872 F: FnMut(&mut Self),
3873 {
3874 use HirScalarExpr::*;
3875 match self {
3876 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3877 CallUnary { expr, .. } => f(expr),
3878 CallBinary { expr1, expr2, .. } => {
3879 f(expr1);
3880 f(expr2);
3881 }
3882 CallVariadic { exprs, .. } => {
3883 for expr in exprs {
3884 f(expr);
3885 }
3886 }
3887 If {
3888 cond,
3889 then,
3890 els,
3891 name: _,
3892 } => {
3893 f(cond);
3894 f(then);
3895 f(els);
3896 }
3897 Exists(..) | Select(..) => (),
3898 Windowing(expr, _name) => expr.visit_mut_children(f),
3899 }
3900 }
3901
3902 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3903 where
3904 F: FnMut(&Self) -> Result<(), E>,
3905 E: From<RecursionLimitError>,
3906 {
3907 use HirScalarExpr::*;
3908 match self {
3909 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3910 CallUnary { expr, .. } => f(expr)?,
3911 CallBinary { expr1, expr2, .. } => {
3912 f(expr1)?;
3913 f(expr2)?;
3914 }
3915 CallVariadic { exprs, .. } => {
3916 for expr in exprs {
3917 f(expr)?;
3918 }
3919 }
3920 If {
3921 cond,
3922 then,
3923 els,
3924 name: _,
3925 } => {
3926 f(cond)?;
3927 f(then)?;
3928 f(els)?;
3929 }
3930 Exists(..) | Select(..) => (),
3931 Windowing(expr, _name) => expr.try_visit_children(f)?,
3932 }
3933 Ok(())
3934 }
3935
3936 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3937 where
3938 F: FnMut(&mut Self) -> Result<(), E>,
3939 E: From<RecursionLimitError>,
3940 {
3941 use HirScalarExpr::*;
3942 match self {
3943 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3944 CallUnary { expr, .. } => f(expr)?,
3945 CallBinary { expr1, expr2, .. } => {
3946 f(expr1)?;
3947 f(expr2)?;
3948 }
3949 CallVariadic { exprs, .. } => {
3950 for expr in exprs {
3951 f(expr)?;
3952 }
3953 }
3954 If {
3955 cond,
3956 then,
3957 els,
3958 name: _,
3959 } => {
3960 f(cond)?;
3961 f(then)?;
3962 f(els)?;
3963 }
3964 Exists(..) | Select(..) => (),
3965 Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3966 }
3967 Ok(())
3968 }
3969}
3970
3971impl AbstractExpr for HirScalarExpr {
3972 type Type = SqlColumnType;
3973
3974 fn typ(
3975 &self,
3976 outers: &[SqlRelationType],
3977 inner: &SqlRelationType,
3978 params: &BTreeMap<usize, SqlScalarType>,
3979 ) -> Self::Type {
3980 stack::maybe_grow(|| match self {
3981 HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3982 if *level == 0 {
3983 inner.column_types[*column].clone()
3984 } else {
3985 outers[*level - 1].column_types[*column].clone()
3986 }
3987 }
3988 HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3989 HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3990 HirScalarExpr::CallUnmaterializable(func, _name) => func.output_sql_type(),
3991 HirScalarExpr::CallUnary {
3992 expr,
3993 func,
3994 name: _,
3995 } => func.output_sql_type(expr.typ(outers, inner, params)),
3996 HirScalarExpr::CallBinary {
3997 expr1,
3998 expr2,
3999 func,
4000 name: _,
4001 } => func.output_sql_type(&[
4002 expr1.typ(outers, inner, params),
4003 expr2.typ(outers, inner, params),
4004 ]),
4005 HirScalarExpr::CallVariadic {
4006 exprs,
4007 func,
4008 name: _,
4009 } => func.output_sql_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
4010 HirScalarExpr::If {
4011 cond: _,
4012 then,
4013 els,
4014 name: _,
4015 } => {
4016 let then_type = then.typ(outers, inner, params);
4017 let else_type = els.typ(outers, inner, params);
4018 then_type.sql_union(&else_type).unwrap() }
4020 HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
4021 HirScalarExpr::Select(expr, _name) => {
4022 let mut outers = outers.to_vec();
4023 outers.insert(0, inner.clone());
4024 expr.typ(&outers, params)
4025 .column_types
4026 .into_element()
4027 .nullable(true)
4028 }
4029 HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
4030 })
4031 }
4032}
4033
4034impl AggregateExpr {
4035 pub fn typ(
4036 &self,
4037 outers: &[SqlRelationType],
4038 inner: &SqlRelationType,
4039 params: &BTreeMap<usize, SqlScalarType>,
4040 ) -> SqlColumnType {
4041 self.func
4042 .output_sql_type(self.expr.typ(outers, inner, params))
4043 }
4044
4045 pub fn is_count_asterisk(&self) -> bool {
4053 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
4054 }
4055}