1use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26 BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Marker type that instantiates the generic `mz_expr::virtual_syntax` traits
/// (`IR`, `AlgExcept`) for the HIR expression types of this module.
// NOTE(review): presumably the crate enables the `missing_debug_implementations`
// lint; this field-less marker is never printed, hence the allow.
#[allow(missing_debug_implementations)]
pub struct Hir;
48
/// Binds HIR's relation and scalar expression types to the generic `IR` trait.
impl IR for Hir {
    /// Relation expressions are [`HirRelationExpr`].
    type Relation = HirRelationExpr;
    /// Scalar expressions are [`HirScalarExpr`].
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55 fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56 if *all {
57 let rhs = rhs.negate();
58 HirRelationExpr::union(lhs, rhs).threshold()
59 } else {
60 let lhs = lhs.distinct();
61 let rhs = rhs.distinct().negate();
62 HirRelationExpr::union(lhs, rhs).threshold()
63 }
64 }
65
66 fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67 let mut result = None;
68
69 use HirRelationExpr::*;
70 if let Threshold { input } = expr {
71 if let Union { base: lhs, inputs } = input.as_ref() {
72 if let [rhs] = &inputs[..] {
73 if let Negate { input: rhs } = rhs {
74 match (lhs.as_ref(), rhs.as_ref()) {
75 (Distinct { input: lhs }, Distinct { input: rhs }) => {
76 let all = false;
77 let lhs = lhs.as_ref();
78 let rhs = rhs.as_ref();
79 result = Some(Except { all, lhs, rhs })
80 }
81 (lhs, rhs) => {
82 let all = true;
83 result = Some(Except { all, lhs, rhs })
84 }
85 }
86 }
87 }
88 }
89 }
90
91 result
92 }
93}
94
/// A relation expression in the planner's high-level intermediate representation (HIR).
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order, so the
/// order of the variants below is observable.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum HirRelationExpr {
    /// A literal relation made of `rows`, with relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection or binding identified by `id`, of type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// A group of recursive bindings that are in scope in `body`.
    LetRec {
        /// Optional limit on the recursion (see `mz_expr::LetRecLimit`).
        limit: Option<LetRecLimit>,
        /// One entry per binding: `(name, id, value, type of value)`.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// The expression in which the bindings are visible.
        body: Box<HirRelationExpr>,
    },
    /// A single non-recursive binding of `value` to `id`, visible in `body`.
    Let {
        /// User-facing name of the binding.
        name: String,
        id: mz_expr::LocalId,
        /// The bound relation.
        value: Box<HirRelationExpr>,
        /// The expression in which the binding is visible.
        body: Box<HirRelationExpr>,
    },
    /// Keeps the columns of `input` listed in `outputs` (by index), in that order.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Evaluates `scalars` against each row of `input`.
    // NOTE(review): presumably the results are appended as new columns — confirm.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Invokes the table function `func` on the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps the rows of `input` that satisfy `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Joins `left` and `right` on the condition `on` using join kind `kind`.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Groups `input` by the `group_key` columns and computes `aggregates` per group.
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        /// Optional size hint for the groups.
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Per `group_key`, orders rows by `order_key` and keeps a `limit`/`offset` window.
    TopK {
        /// The input relation.
        input: Box<HirRelationExpr>,
        /// Columns that define the groups.
        group_key: Vec<usize>,
        /// Ordering applied within each group.
        order_key: Vec<ColumnOrder>,
        /// Optional maximum number of rows to keep per group.
        limit: Option<HirScalarExpr>,
        /// Number of leading rows to skip per group.
        offset: HirScalarExpr,
        /// Optional size hint for the groups.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicities of `input` (used with `Union` + `Threshold` for `EXCEPT`).
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Applied to `input` as the outer operator of the `EXCEPT` encoding.
    // NOTE(review): presumably drops rows with non-positive multiplicity — confirm.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// Union of `base` with every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
208
/// Optional user-facing name carried alongside a scalar expression.
///
/// Wrapped in `TreatAsEqual`. NOTE(review): presumably this keeps the name from
/// influencing the derived equality/ordering/hashing of `HirScalarExpr` — confirm
/// against `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
211
/// A scalar expression in the planner's high-level intermediate representation (HIR).
///
/// Every variant carries a [`NameMetadata`] with an optional user-facing name.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum HirScalarExpr {
    /// A reference to a column, possibly of an outer scope (see [`ColumnRef`]).
    Column(ColumnRef, NameMetadata),
    /// A reference to the query parameter with the given index.
    Parameter(usize, NameMetadata),
    /// A literal value packed in a `Row`, together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to a function without arguments whose value is not known at plan time.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a one-argument function.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a two-argument function.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function with any number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// `then` if `cond` holds, otherwise `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A subquery used as an existence test (presumably `EXISTS (...)`).
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// A subquery used as a scalar value (presumably `(SELECT ...)`).
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call.
    Windowing(WindowExpr, NameMetadata),
}
263
/// A window function call together with its partitioning and ordering expressions.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct WindowExpr {
    /// The window function being applied.
    pub func: WindowExprType,
    /// Expressions that partition the input rows.
    pub partition_by: Vec<HirScalarExpr>,
    /// Expressions that order rows within a partition.
    pub order_by: Vec<HirScalarExpr>,
}
292
293impl WindowExpr {
294 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
295 where
296 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
297 {
298 #[allow(deprecated)]
299 self.func.visit_expressions(f)?;
300 for expr in self.partition_by.iter() {
301 f(expr)?;
302 }
303 for expr in self.order_by.iter() {
304 f(expr)?;
305 }
306 Ok(())
307 }
308
309 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
310 where
311 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
312 {
313 #[allow(deprecated)]
314 self.func.visit_expressions_mut(f)?;
315 for expr in self.partition_by.iter_mut() {
316 f(expr)?;
317 }
318 for expr in self.order_by.iter_mut() {
319 f(expr)?;
320 }
321 Ok(())
322 }
323}
324
325impl VisitChildren<HirScalarExpr> for WindowExpr {
326 fn visit_children<F>(&self, mut f: F)
327 where
328 F: FnMut(&HirScalarExpr),
329 {
330 self.func.visit_children(&mut f);
331 for expr in self.partition_by.iter() {
332 f(expr);
333 }
334 for expr in self.order_by.iter() {
335 f(expr);
336 }
337 }
338
339 fn visit_mut_children<F>(&mut self, mut f: F)
340 where
341 F: FnMut(&mut HirScalarExpr),
342 {
343 self.func.visit_mut_children(&mut f);
344 for expr in self.partition_by.iter_mut() {
345 f(expr);
346 }
347 for expr in self.order_by.iter_mut() {
348 f(expr);
349 }
350 }
351
352 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
353 where
354 F: FnMut(&HirScalarExpr) -> Result<(), E>,
355 E: From<RecursionLimitError>,
356 {
357 self.func.try_visit_children(&mut f)?;
358 for expr in self.partition_by.iter() {
359 f(expr)?;
360 }
361 for expr in self.order_by.iter() {
362 f(expr)?;
363 }
364 Ok(())
365 }
366
367 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
368 where
369 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
370 E: From<RecursionLimitError>,
371 {
372 self.func.try_visit_mut_children(&mut f)?;
373 for expr in self.partition_by.iter_mut() {
374 f(expr)?;
375 }
376 for expr in self.order_by.iter_mut() {
377 f(expr)?;
378 }
379 Ok(())
380 }
381}
382
/// The kind of window function wrapped by a [`WindowExpr`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum WindowExprType {
    /// A function without argument expressions (`row_number`, `rank`, `dense_rank`).
    Scalar(ScalarWindowExpr),
    /// A function over one argument expression (`lag`, `lead`, `first_value`, ...).
    Value(ValueWindowExpr),
    /// An aggregate function evaluated over a window frame.
    Aggregate(AggregateWindowExpr),
}
415
416impl WindowExprType {
417 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
418 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
419 where
420 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
421 {
422 #[allow(deprecated)]
423 match self {
424 Self::Scalar(expr) => expr.visit_expressions(f),
425 Self::Value(expr) => expr.visit_expressions(f),
426 Self::Aggregate(expr) => expr.visit_expressions(f),
427 }
428 }
429
430 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
431 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
432 where
433 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
434 {
435 #[allow(deprecated)]
436 match self {
437 Self::Scalar(expr) => expr.visit_expressions_mut(f),
438 Self::Value(expr) => expr.visit_expressions_mut(f),
439 Self::Aggregate(expr) => expr.visit_expressions_mut(f),
440 }
441 }
442
443 fn typ(
444 &self,
445 outers: &[SqlRelationType],
446 inner: &SqlRelationType,
447 params: &BTreeMap<usize, SqlScalarType>,
448 ) -> SqlColumnType {
449 match self {
450 Self::Scalar(expr) => expr.typ(outers, inner, params),
451 Self::Value(expr) => expr.typ(outers, inner, params),
452 Self::Aggregate(expr) => expr.typ(outers, inner, params),
453 }
454 }
455}
456
457impl VisitChildren<HirScalarExpr> for WindowExprType {
458 fn visit_children<F>(&self, f: F)
459 where
460 F: FnMut(&HirScalarExpr),
461 {
462 match self {
463 Self::Scalar(_) => (),
464 Self::Value(expr) => expr.visit_children(f),
465 Self::Aggregate(expr) => expr.visit_children(f),
466 }
467 }
468
469 fn visit_mut_children<F>(&mut self, f: F)
470 where
471 F: FnMut(&mut HirScalarExpr),
472 {
473 match self {
474 Self::Scalar(_) => (),
475 Self::Value(expr) => expr.visit_mut_children(f),
476 Self::Aggregate(expr) => expr.visit_mut_children(f),
477 }
478 }
479
480 fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
481 where
482 F: FnMut(&HirScalarExpr) -> Result<(), E>,
483 E: From<RecursionLimitError>,
484 {
485 match self {
486 Self::Scalar(_) => Ok(()),
487 Self::Value(expr) => expr.try_visit_children(f),
488 Self::Aggregate(expr) => expr.try_visit_children(f),
489 }
490 }
491
492 fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
493 where
494 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
495 E: From<RecursionLimitError>,
496 {
497 match self {
498 Self::Scalar(_) => Ok(()),
499 Self::Value(expr) => expr.try_visit_mut_children(f),
500 Self::Aggregate(expr) => expr.try_visit_mut_children(f),
501 }
502 }
503}
504
/// A window function without argument expressions (see [`ScalarWindowFunc`]).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct ScalarWindowExpr {
    /// Which scalar window function this is.
    pub func: ScalarWindowFunc,
    /// Column ordering, passed through to the `mz_expr` aggregate by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
}
520
521impl ScalarWindowExpr {
522 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
523 pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
524 where
525 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
526 {
527 match self.func {
528 ScalarWindowFunc::RowNumber => {}
529 ScalarWindowFunc::Rank => {}
530 ScalarWindowFunc::DenseRank => {}
531 }
532 Ok(())
533 }
534
535 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
536 pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
537 where
538 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
539 {
540 match self.func {
541 ScalarWindowFunc::RowNumber => {}
542 ScalarWindowFunc::Rank => {}
543 ScalarWindowFunc::DenseRank => {}
544 }
545 Ok(())
546 }
547
548 fn typ(
549 &self,
550 _outers: &[SqlRelationType],
551 _inner: &SqlRelationType,
552 _params: &BTreeMap<usize, SqlScalarType>,
553 ) -> SqlColumnType {
554 self.func.output_type()
555 }
556
557 pub fn into_expr(self) -> mz_expr::AggregateFunc {
558 match self.func {
559 ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
560 order_by: self.order_by,
561 },
562 ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
563 order_by: self.order_by,
564 },
565 ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
566 order_by: self.order_by,
567 },
568 }
569 }
570}
571
/// Window functions that take no argument expressions.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum ScalarWindowFunc {
    /// `row_number()`
    RowNumber,
    /// `rank()`
    Rank,
    /// `dense_rank()`
    DenseRank,
}
589
590impl Display for ScalarWindowFunc {
591 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
592 match self {
593 ScalarWindowFunc::RowNumber => write!(f, "row_number"),
594 ScalarWindowFunc::Rank => write!(f, "rank"),
595 ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
596 }
597 }
598}
599
600impl ScalarWindowFunc {
601 pub fn output_type(&self) -> SqlColumnType {
602 match self {
603 ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
604 ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
605 ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
606 }
607 }
608}
609
/// A window function over a single argument expression (see [`ValueWindowFunc`]).
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct ValueWindowExpr {
    /// Which value window function this is.
    pub func: ValueWindowFunc,
    /// The argument expression.
    ///
    /// For `lag`/`lead` it is a record whose first field is the value being shifted.
    /// For `Fused` it is a record with one field per fused function
    /// (see `ValueWindowFunc::output_type`).
    pub args: Box<HirScalarExpr>,
    /// Column ordering, passed through to the `mz_expr` aggregate by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
    /// Window frame, used by `first_value`/`last_value`.
    pub window_frame: WindowFrame,
    /// `IGNORE NULLS` flag, used by `lag`/`lead`.
    pub ignore_nulls: bool,
}
634
635impl Display for ValueWindowFunc {
636 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
637 match self {
638 ValueWindowFunc::Lag => write!(f, "lag"),
639 ValueWindowFunc::Lead => write!(f, "lead"),
640 ValueWindowFunc::FirstValue => write!(f, "first_value"),
641 ValueWindowFunc::LastValue => write!(f, "last_value"),
642 ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
643 }
644 }
645}
646
647impl ValueWindowExpr {
648 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
649 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
650 where
651 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
652 {
653 f(&self.args)
654 }
655
656 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
657 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
658 where
659 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
660 {
661 f(&mut self.args)
662 }
663
664 fn typ(
665 &self,
666 outers: &[SqlRelationType],
667 inner: &SqlRelationType,
668 params: &BTreeMap<usize, SqlScalarType>,
669 ) -> SqlColumnType {
670 self.func.output_type(self.args.typ(outers, inner, params))
671 }
672
673 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
675 (
676 self.args,
677 self.func
678 .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
679 )
680 }
681}
682
683impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
684 fn visit_children<F>(&self, mut f: F)
685 where
686 F: FnMut(&HirScalarExpr),
687 {
688 f(&self.args)
689 }
690
691 fn visit_mut_children<F>(&mut self, mut f: F)
692 where
693 F: FnMut(&mut HirScalarExpr),
694 {
695 f(&mut self.args)
696 }
697
698 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
699 where
700 F: FnMut(&HirScalarExpr) -> Result<(), E>,
701 E: From<RecursionLimitError>,
702 {
703 f(&self.args)
704 }
705
706 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
707 where
708 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
709 E: From<RecursionLimitError>,
710 {
711 f(&mut self.args)
712 }
713}
714
/// Window functions over a single argument expression.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum ValueWindowFunc {
    /// `lag(...)`
    Lag,
    /// `lead(...)`
    Lead,
    /// `first_value(...)`
    FirstValue,
    /// `last_value(...)`
    LastValue,
    /// Several value window functions evaluated together over a record argument
    /// with one field per function; the output is a record of their outputs.
    Fused(Vec<ValueWindowFunc>),
}
734
735impl ValueWindowFunc {
736 pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
737 match self {
738 ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
739 input_type.scalar_type.unwrap_record_element_type()[0]
741 .clone()
742 .nullable(true)
743 }
744 ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
745 input_type.scalar_type.nullable(true)
746 }
747 ValueWindowFunc::Fused(funcs) => {
748 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
749 SqlScalarType::Record {
750 fields: funcs
751 .iter()
752 .zip_eq(input_types)
753 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
754 .collect(),
755 custom_id: None,
756 }
757 .nullable(false)
758 }
759 }
760 }
761
762 pub fn into_expr(
763 self,
764 order_by: Vec<ColumnOrder>,
765 window_frame: WindowFrame,
766 ignore_nulls: bool,
767 ) -> mz_expr::AggregateFunc {
768 match self {
769 ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
771 order_by,
772 lag_lead: mz_expr::LagLeadType::Lag,
773 ignore_nulls,
774 },
775 ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
776 order_by,
777 lag_lead: mz_expr::LagLeadType::Lead,
778 ignore_nulls,
779 },
780 ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
781 order_by,
782 window_frame,
783 },
784 ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
785 order_by,
786 window_frame,
787 },
788 ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
789 funcs: funcs
790 .into_iter()
791 .map(|func| {
792 func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
793 })
794 .collect(),
795 order_by,
796 },
797 }
798 }
799}
800
/// An aggregate function evaluated over a window.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateWindowExpr {
    /// The aggregate and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// Column ordering, passed through to the `mz_expr` aggregate by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame the aggregate is evaluated over.
    pub window_frame: WindowFrame,
}
817
818impl AggregateWindowExpr {
819 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
820 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
821 where
822 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
823 {
824 f(&self.aggregate_expr.expr)
825 }
826
827 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
828 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
829 where
830 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
831 {
832 f(&mut self.aggregate_expr.expr)
833 }
834
835 fn typ(
836 &self,
837 outers: &[SqlRelationType],
838 inner: &SqlRelationType,
839 params: &BTreeMap<usize, SqlScalarType>,
840 ) -> SqlColumnType {
841 self.aggregate_expr
842 .func
843 .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
844 }
845
846 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
847 if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
848 (
849 self.aggregate_expr.expr,
850 FusedWindowAggregate {
851 wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
852 order_by: self.order_by,
853 window_frame: self.window_frame,
854 },
855 )
856 } else {
857 (
858 self.aggregate_expr.expr,
859 WindowAggregate {
860 wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
861 order_by: self.order_by,
862 window_frame: self.window_frame,
863 },
864 )
865 }
866 }
867}
868
869impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
870 fn visit_children<F>(&self, mut f: F)
871 where
872 F: FnMut(&HirScalarExpr),
873 {
874 f(&self.aggregate_expr.expr)
875 }
876
877 fn visit_mut_children<F>(&mut self, mut f: F)
878 where
879 F: FnMut(&mut HirScalarExpr),
880 {
881 f(&mut self.aggregate_expr.expr)
882 }
883
884 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
885 where
886 F: FnMut(&HirScalarExpr) -> Result<(), E>,
887 E: From<RecursionLimitError>,
888 {
889 f(&self.aggregate_expr.expr)
890 }
891
892 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
893 where
894 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
895 E: From<RecursionLimitError>,
896 {
897 f(&mut self.aggregate_expr.expr)
898 }
899}
900
/// A scalar expression whose type may not be determined yet.
///
/// Only `Coerced` has a concrete type. The literal and parameter forms stay
/// uncoerced until they are resolved through `typeconv::plan_coerce` (see
/// `type_as`, `type_as_any` and `cast_to`).
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that already has a concrete type.
    Coerced(HirScalarExpr),
    /// A parameter reference (by index) whose type is not yet known.
    Parameter(usize),
    /// A `NULL` literal.
    LiteralNull,
    /// A string literal.
    LiteralString(String),
    /// A record literal whose fields may themselves be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
933
934impl CoercibleScalarExpr {
935 pub fn type_as(
936 self,
937 ecx: &ExprContext,
938 ty: &SqlScalarType,
939 ) -> Result<HirScalarExpr, PlanError> {
940 let expr = typeconv::plan_coerce(ecx, self, ty)?;
941 let expr_ty = ecx.scalar_type(&expr);
942 if ty != &expr_ty {
943 sql_bail!(
944 "{} must have type {}, not type {}",
945 ecx.name,
946 ecx.humanize_scalar_type(ty, false),
947 ecx.humanize_scalar_type(&expr_ty, false),
948 );
949 }
950 Ok(expr)
951 }
952
953 pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
954 typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
955 }
956
957 pub fn cast_to(
958 self,
959 ecx: &ExprContext,
960 ccx: CastContext,
961 ty: &SqlScalarType,
962 ) -> Result<HirScalarExpr, PlanError> {
963 let expr = typeconv::plan_coerce(ecx, self, ty)?;
964 typeconv::plan_cast(ecx, ccx, expr, ty)
965 }
966}
967
/// The column type of a [`CoercibleScalarExpr`], which may not be fully known.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// A concrete column type.
    Coerced(SqlColumnType),
    /// The type of a record literal, one entry per field.
    Record(Vec<CoercibleColumnType>),
    /// The type of an uncoerced literal or parameter.
    Uncoerced,
}
975
976impl CoercibleColumnType {
977 pub fn nullable(&self) -> bool {
979 match self {
980 CoercibleColumnType::Coerced(ct) => ct.nullable,
982
983 CoercibleColumnType::Record(_) => false,
985
986 CoercibleColumnType::Uncoerced => true,
989 }
990 }
991}
992
/// The scalar type of a [`CoercibleScalarExpr`], which may not be fully known.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A concrete scalar type.
    Coerced(SqlScalarType),
    /// The type of a record literal, one column type per field.
    Record(Vec<CoercibleColumnType>),
    /// The type of an uncoerced literal or parameter.
    Uncoerced,
}
1000
1001impl CoercibleScalarType {
1002 pub fn is_coerced(&self) -> bool {
1004 matches!(self, CoercibleScalarType::Coerced(_))
1005 }
1006
1007 pub fn as_coerced(&self) -> Option<&SqlScalarType> {
1009 match self {
1010 CoercibleScalarType::Coerced(t) => Some(t),
1011 _ => None,
1012 }
1013 }
1014
1015 pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
1018 where
1019 F: FnOnce(SqlScalarType) -> SqlScalarType,
1020 {
1021 match self {
1022 CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
1023 _ => self,
1024 }
1025 }
1026
1027 pub fn force_coerced_if_record(&mut self) {
1034 fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
1035 let mut fields = vec![];
1036 for (i, uf) in uncoerced_fields.enumerate() {
1037 let name = ColumnName::from(format!("f{}", i + 1));
1038 let ty = match uf {
1039 CoercibleColumnType::Coerced(ty) => ty,
1040 CoercibleColumnType::Record(mut fields) => {
1041 convert(fields.drain(..)).nullable(false)
1042 }
1043 CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
1044 };
1045 fields.push((name, ty))
1046 }
1047 SqlScalarType::Record {
1048 fields: fields.into(),
1049 custom_id: None,
1050 }
1051 }
1052
1053 if let CoercibleScalarType::Record(fields) = self {
1054 *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
1055 }
1056 }
1057}
1058
/// An expression whose type can be computed from its typing environment.
pub trait AbstractExpr {
    /// The kind of type this expression reports.
    type Type: AbstractColumnType;

    /// Computes the type of this expression.
    ///
    /// `inner` is the type of the relation the expression is evaluated against, and
    /// `params` maps parameter indices to known types.
    /// NOTE(review): `outers` presumably holds the relation types of enclosing
    /// (correlated) scopes, indexed by `ColumnRef::level` — confirm.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
1073
1074impl AbstractExpr for CoercibleScalarExpr {
1075 type Type = CoercibleColumnType;
1076
1077 fn typ(
1078 &self,
1079 outers: &[SqlRelationType],
1080 inner: &SqlRelationType,
1081 params: &BTreeMap<usize, SqlScalarType>,
1082 ) -> Self::Type {
1083 match self {
1084 CoercibleScalarExpr::Coerced(expr) => {
1085 CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
1086 }
1087 CoercibleScalarExpr::LiteralRecord(scalars) => {
1088 let fields = scalars
1089 .iter()
1090 .map(|s| s.typ(outers, inner, params))
1091 .collect();
1092 CoercibleColumnType::Record(fields)
1093 }
1094 _ => CoercibleColumnType::Uncoerced,
1095 }
1096 }
1097}
1098
/// A column type that can be reduced to its scalar type.
pub trait AbstractColumnType {
    /// The scalar-type counterpart of this column type.
    type AbstractScalarType;

    /// Consumes the column type and returns its scalar type.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1110
1111impl AbstractColumnType for SqlColumnType {
1112 type AbstractScalarType = SqlScalarType;
1113
1114 fn scalar_type(self) -> Self::AbstractScalarType {
1115 self.scalar_type
1116 }
1117}
1118
1119impl AbstractColumnType for CoercibleColumnType {
1120 type AbstractScalarType = CoercibleScalarType;
1121
1122 fn scalar_type(self) -> Self::AbstractScalarType {
1123 match self {
1124 CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1125 CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1126 CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1127 }
1128 }
1129}
1130
1131impl From<HirScalarExpr> for CoercibleScalarExpr {
1132 fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1133 CoercibleScalarExpr::Coerced(expr)
1134 }
1135}
1136
/// A reference to a column, identified by scope level and column index.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize
)]
pub struct ColumnRef {
    /// Which scope the column belongs to.
    /// NOTE(review): presumably 0 is the innermost scope and larger values count
    /// outward through enclosing (correlated) scopes — confirm.
    pub level: usize,
    /// Index of the column within that scope's relation.
    pub column: usize,
}
1169
/// The kind of a [`HirRelationExpr::Join`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
    /// Right outer join.
    RightOuter,
    /// Full outer join.
    FullOuter,
}
1187
1188impl fmt::Display for JoinKind {
1189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1190 write!(
1191 f,
1192 "{}",
1193 match self {
1194 JoinKind::Inner => "Inner",
1195 JoinKind::LeftOuter => "LeftOuter",
1196 JoinKind::RightOuter => "RightOuter",
1197 JoinKind::FullOuter => "FullOuter",
1198 }
1199 )
1200 }
1201}
1202
1203impl JoinKind {
1204 pub fn can_be_correlated(&self) -> bool {
1205 match self {
1206 JoinKind::Inner | JoinKind::LeftOuter => true,
1207 JoinKind::RightOuter | JoinKind::FullOuter => false,
1208 }
1209 }
1210}
1211
/// An aggregate function applied to an argument expression.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The argument expression the aggregate consumes.
    pub expr: Box<HirScalarExpr>,
    /// Whether the aggregate applies to distinct argument values only (`DISTINCT`).
    pub distinct: bool,
}
1228
/// HIR-level aggregate functions.
///
/// Apart from `FusedWindowAgg`, every variant lowers one-to-one to the same-named
/// `mz_expr::AggregateFunc` variant via `AggregateFunc::into_expr`.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub enum AggregateFunc {
    // `max` over each supported input type.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // `min` over each supported input type.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // `sum` over each supported input type.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// `count`; output type `Int64`.
    Count,
    /// Boolean "any"; identity value `false`.
    Any,
    /// Boolean "all"; identity value `true`.
    All,
    /// `jsonb_agg`; output type `Jsonb`.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// `jsonb_object_agg`; output type `Jsonb`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Map aggregation whose values have type `value_type`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: SqlScalarType,
    },
    /// Array concatenation; identity value is the empty array.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// List concatenation; identity value is the empty list.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// `string_agg`; output type `String`.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Several window aggregations evaluated together.
    ///
    /// Lowered only by `AggregateWindowExpr::into_expr`;
    /// `AggregateFunc::into_expr` and `identity_datum` panic on this variant.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Placeholder aggregate; lowers to `mz_expr::AggregateFunc::Dummy` and has
    /// identity `Datum::Dummy`.
    Dummy,
}
1342
1343impl AggregateFunc {
1344 pub fn into_expr(self) -> mz_expr::AggregateFunc {
1346 match self {
1347 AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1348 AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1349 AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1350 AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1351 AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1352 AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1353 AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1354 AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1355 AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1356 AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1357 AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1358 AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1359 AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1360 AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1361 AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1362 AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1363 AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1364 AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1365 AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1366 AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1367 AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1368 AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1369 AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1370 AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1371 AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1372 AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1373 AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1374 AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1375 AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1376 AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1377 AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1378 AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1379 AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1380 AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1381 AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1382 AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1383 AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1384 AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1385 AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1386 AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1387 AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1388 AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1389 AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1390 AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1391 AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1392 AggregateFunc::All => mz_expr::AggregateFunc::All,
1393 AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1394 AggregateFunc::JsonbObjectAgg { order_by } => {
1395 mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1396 }
1397 AggregateFunc::MapAgg {
1398 order_by,
1399 value_type,
1400 } => mz_expr::AggregateFunc::MapAgg {
1401 order_by,
1402 value_type,
1403 },
1404 AggregateFunc::ArrayConcat { order_by } => {
1405 mz_expr::AggregateFunc::ArrayConcat { order_by }
1406 }
1407 AggregateFunc::ListConcat { order_by } => {
1408 mz_expr::AggregateFunc::ListConcat { order_by }
1409 }
1410 AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1411 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1414 panic!("into_expr called on FusedWindowAgg")
1415 }
1416 AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1417 }
1418 }
1419
1420 pub fn identity_datum(&self) -> Datum<'static> {
1427 match self {
1428 AggregateFunc::Any => Datum::False,
1429 AggregateFunc::All => Datum::True,
1430 AggregateFunc::Dummy => Datum::Dummy,
1431 AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1432 AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1433 AggregateFunc::MaxNumeric
1434 | AggregateFunc::MaxInt16
1435 | AggregateFunc::MaxInt32
1436 | AggregateFunc::MaxInt64
1437 | AggregateFunc::MaxUInt16
1438 | AggregateFunc::MaxUInt32
1439 | AggregateFunc::MaxUInt64
1440 | AggregateFunc::MaxMzTimestamp
1441 | AggregateFunc::MaxFloat32
1442 | AggregateFunc::MaxFloat64
1443 | AggregateFunc::MaxBool
1444 | AggregateFunc::MaxString
1445 | AggregateFunc::MaxDate
1446 | AggregateFunc::MaxTimestamp
1447 | AggregateFunc::MaxTimestampTz
1448 | AggregateFunc::MaxInterval
1449 | AggregateFunc::MaxTime
1450 | AggregateFunc::MinNumeric
1451 | AggregateFunc::MinInt16
1452 | AggregateFunc::MinInt32
1453 | AggregateFunc::MinInt64
1454 | AggregateFunc::MinUInt16
1455 | AggregateFunc::MinUInt32
1456 | AggregateFunc::MinUInt64
1457 | AggregateFunc::MinMzTimestamp
1458 | AggregateFunc::MinFloat32
1459 | AggregateFunc::MinFloat64
1460 | AggregateFunc::MinBool
1461 | AggregateFunc::MinString
1462 | AggregateFunc::MinDate
1463 | AggregateFunc::MinTimestamp
1464 | AggregateFunc::MinTimestampTz
1465 | AggregateFunc::MinInterval
1466 | AggregateFunc::MinTime
1467 | AggregateFunc::SumInt16
1468 | AggregateFunc::SumInt32
1469 | AggregateFunc::SumInt64
1470 | AggregateFunc::SumUInt16
1471 | AggregateFunc::SumUInt32
1472 | AggregateFunc::SumUInt64
1473 | AggregateFunc::SumFloat32
1474 | AggregateFunc::SumFloat64
1475 | AggregateFunc::SumNumeric
1476 | AggregateFunc::Count
1477 | AggregateFunc::JsonbAgg { .. }
1478 | AggregateFunc::JsonbObjectAgg { .. }
1479 | AggregateFunc::MapAgg { .. }
1480 | AggregateFunc::StringAgg { .. } => Datum::Null,
1481 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1482 panic!("FusedWindowAgg doesn't have an identity_datum")
1492 }
1493 }
1494 }
1495
1496 pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1502 let scalar_type = match self {
1503 AggregateFunc::Count => SqlScalarType::Int64,
1504 AggregateFunc::Any => SqlScalarType::Bool,
1505 AggregateFunc::All => SqlScalarType::Bool,
1506 AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1507 AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1508 AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1509 AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1510 AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1511 max_scale: Some(NumericMaxScale::ZERO),
1512 },
1513 AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1514 AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1515 max_scale: Some(NumericMaxScale::ZERO),
1516 },
1517 AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1518 value_type: Box::new(value_type.clone()),
1519 custom_id: None,
1520 },
1521 AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1522 match input_type.scalar_type {
1523 SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1525 _ => unreachable!(),
1526 }
1527 }
1528 AggregateFunc::MaxNumeric
1529 | AggregateFunc::MaxInt16
1530 | AggregateFunc::MaxInt32
1531 | AggregateFunc::MaxInt64
1532 | AggregateFunc::MaxUInt16
1533 | AggregateFunc::MaxUInt32
1534 | AggregateFunc::MaxUInt64
1535 | AggregateFunc::MaxMzTimestamp
1536 | AggregateFunc::MaxFloat32
1537 | AggregateFunc::MaxFloat64
1538 | AggregateFunc::MaxBool
1539 | AggregateFunc::MaxString
1540 | AggregateFunc::MaxDate
1541 | AggregateFunc::MaxTimestamp
1542 | AggregateFunc::MaxTimestampTz
1543 | AggregateFunc::MaxInterval
1544 | AggregateFunc::MaxTime
1545 | AggregateFunc::MinNumeric
1546 | AggregateFunc::MinInt16
1547 | AggregateFunc::MinInt32
1548 | AggregateFunc::MinInt64
1549 | AggregateFunc::MinUInt16
1550 | AggregateFunc::MinUInt32
1551 | AggregateFunc::MinUInt64
1552 | AggregateFunc::MinMzTimestamp
1553 | AggregateFunc::MinFloat32
1554 | AggregateFunc::MinFloat64
1555 | AggregateFunc::MinBool
1556 | AggregateFunc::MinString
1557 | AggregateFunc::MinDate
1558 | AggregateFunc::MinTimestamp
1559 | AggregateFunc::MinTimestampTz
1560 | AggregateFunc::MinInterval
1561 | AggregateFunc::MinTime
1562 | AggregateFunc::SumFloat32
1563 | AggregateFunc::SumFloat64
1564 | AggregateFunc::SumNumeric
1565 | AggregateFunc::Dummy => input_type.scalar_type,
1566 AggregateFunc::FusedWindowAgg { funcs } => {
1567 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1568 SqlScalarType::Record {
1569 fields: funcs
1570 .iter()
1571 .zip_eq(input_types)
1572 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1573 .collect(),
1574 custom_id: None,
1575 }
1576 }
1577 };
1578 let nullable = !matches!(self, AggregateFunc::Count);
1580 scalar_type.nullable(nullable)
1581 }
1582
1583 pub fn is_order_sensitive(&self) -> bool {
1584 use AggregateFunc::*;
1585 matches!(
1586 self,
1587 JsonbAgg { .. }
1588 | JsonbObjectAgg { .. }
1589 | MapAgg { .. }
1590 | ArrayConcat { .. }
1591 | ListConcat { .. }
1592 | StringAgg { .. }
1593 )
1594 }
1595}
1596
1597impl HirRelationExpr {
1598 pub fn top_level_typ(&self) -> SqlRelationType {
1600 self.typ(&[], &BTreeMap::new())
1601 }
1602
1603 pub fn typ(
1608 &self,
1609 outers: &[SqlRelationType],
1610 params: &BTreeMap<usize, SqlScalarType>,
1611 ) -> SqlRelationType {
1612 stack::maybe_grow(|| match self {
1613 HirRelationExpr::Constant { typ, .. } => typ.clone(),
1614 HirRelationExpr::Get { typ, .. } => typ.clone(),
1615 HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1616 HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1617 HirRelationExpr::Project { input, outputs } => {
1618 let input_typ = input.typ(outers, params);
1619 SqlRelationType::new(
1620 outputs
1621 .iter()
1622 .map(|&i| input_typ.column_types[i].clone())
1623 .collect(),
1624 )
1625 }
1626 HirRelationExpr::Map { input, scalars } => {
1627 let mut typ = input.typ(outers, params);
1628 for scalar in scalars {
1629 typ.column_types.push(scalar.typ(outers, &typ, params));
1630 }
1631 typ
1632 }
1633 HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1634 HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1635 input.typ(outers, params)
1636 }
1637 HirRelationExpr::Join {
1638 left, right, kind, ..
1639 } => {
1640 let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1641 let right_nullable =
1642 matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1643 let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1644 let nullable = t.nullable || left_nullable;
1645 t.nullable(nullable)
1646 });
1647 let mut outers = outers.to_vec();
1648 outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1649 let rt = right
1650 .typ(&outers, params)
1651 .column_types
1652 .into_iter()
1653 .map(|t| {
1654 let nullable = t.nullable || right_nullable;
1655 t.nullable(nullable)
1656 });
1657 SqlRelationType::new(lt.chain(rt).collect())
1658 }
1659 HirRelationExpr::Reduce {
1660 input,
1661 group_key,
1662 aggregates,
1663 expected_group_size: _,
1664 } => {
1665 let input_typ = input.typ(outers, params);
1666 let mut column_types = group_key
1667 .iter()
1668 .map(|&i| input_typ.column_types[i].clone())
1669 .collect::<Vec<_>>();
1670 for agg in aggregates {
1671 column_types.push(agg.typ(outers, &input_typ, params));
1672 }
1673 SqlRelationType::new(column_types)
1675 }
1676 HirRelationExpr::Distinct { input }
1678 | HirRelationExpr::Negate { input }
1679 | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1680 HirRelationExpr::Union { base, inputs } => {
1681 let mut base_cols = base.typ(outers, params).column_types;
1682 for input in inputs {
1683 for (base_col, col) in base_cols
1684 .iter_mut()
1685 .zip_eq(input.typ(outers, params).column_types)
1686 {
1687 *base_col = base_col.sql_union(&col).unwrap(); }
1689 }
1690 SqlRelationType::new(base_cols)
1691 }
1692 })
1693 }
1694
1695 pub fn arity(&self) -> usize {
1696 match self {
1697 HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1698 HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1699 HirRelationExpr::Let { body, .. } => body.arity(),
1700 HirRelationExpr::LetRec { body, .. } => body.arity(),
1701 HirRelationExpr::Project { outputs, .. } => outputs.len(),
1702 HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1703 HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1704 HirRelationExpr::Filter { input, .. }
1705 | HirRelationExpr::TopK { input, .. }
1706 | HirRelationExpr::Distinct { input }
1707 | HirRelationExpr::Negate { input }
1708 | HirRelationExpr::Threshold { input } => input.arity(),
1709 HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1710 HirRelationExpr::Union { base, .. } => base.arity(),
1711 HirRelationExpr::Reduce {
1712 group_key,
1713 aggregates,
1714 ..
1715 } => group_key.len() + aggregates.len(),
1716 }
1717 }
1718
1719 pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1721 match self {
1722 Self::Constant { rows, typ } => Some((rows, typ)),
1723 _ => None,
1724 }
1725 }
1726
1727 pub fn is_correlated(&self) -> bool {
1730 let mut correlated = false;
1731 #[allow(deprecated)]
1732 self.visit_columns(0, &mut |depth, col| {
1733 if col.level > depth && col.level - depth == 1 {
1734 correlated = true;
1735 }
1736 });
1737 correlated
1738 }
1739
1740 pub fn is_join_identity(&self) -> bool {
1741 match self {
1742 HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1743 _ => false,
1744 }
1745 }
1746
1747 pub fn project(self, outputs: Vec<usize>) -> Self {
1748 if outputs.iter().copied().eq(0..self.arity()) {
1749 self
1751 } else {
1752 HirRelationExpr::Project {
1753 input: Box::new(self),
1754 outputs,
1755 }
1756 }
1757 }
1758
1759 pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1760 if scalars.is_empty() {
1761 self
1763 } else if let HirRelationExpr::Map {
1764 scalars: old_scalars,
1765 input: _,
1766 } = &mut self
1767 {
1768 old_scalars.extend(scalars);
1770 self
1771 } else {
1772 HirRelationExpr::Map {
1773 input: Box::new(self),
1774 scalars,
1775 }
1776 }
1777 }
1778
1779 pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1780 if let HirRelationExpr::Filter {
1781 input: _,
1782 predicates,
1783 } = &mut self
1784 {
1785 predicates.extend(preds);
1786 predicates.sort();
1787 predicates.dedup();
1788 self
1789 } else {
1790 preds.sort();
1791 preds.dedup();
1792 HirRelationExpr::Filter {
1793 input: Box::new(self),
1794 predicates: preds,
1795 }
1796 }
1797 }
1798
1799 pub fn reduce(
1800 self,
1801 group_key: Vec<usize>,
1802 aggregates: Vec<AggregateExpr>,
1803 expected_group_size: Option<u64>,
1804 ) -> Self {
1805 HirRelationExpr::Reduce {
1806 input: Box::new(self),
1807 group_key,
1808 aggregates,
1809 expected_group_size,
1810 }
1811 }
1812
1813 pub fn top_k(
1814 self,
1815 group_key: Vec<usize>,
1816 order_key: Vec<ColumnOrder>,
1817 limit: Option<HirScalarExpr>,
1818 offset: HirScalarExpr,
1819 expected_group_size: Option<u64>,
1820 ) -> Self {
1821 HirRelationExpr::TopK {
1822 input: Box::new(self),
1823 group_key,
1824 order_key,
1825 limit,
1826 offset,
1827 expected_group_size,
1828 }
1829 }
1830
1831 pub fn negate(self) -> Self {
1832 if let HirRelationExpr::Negate { input } = self {
1833 *input
1834 } else {
1835 HirRelationExpr::Negate {
1836 input: Box::new(self),
1837 }
1838 }
1839 }
1840
1841 pub fn distinct(self) -> Self {
1842 if let HirRelationExpr::Distinct { .. } = self {
1843 self
1844 } else {
1845 HirRelationExpr::Distinct {
1846 input: Box::new(self),
1847 }
1848 }
1849 }
1850
1851 pub fn threshold(self) -> Self {
1852 if let HirRelationExpr::Threshold { .. } = self {
1853 self
1854 } else {
1855 HirRelationExpr::Threshold {
1856 input: Box::new(self),
1857 }
1858 }
1859 }
1860
1861 pub fn union(self, other: Self) -> Self {
1862 let mut terms = Vec::new();
1863 if let HirRelationExpr::Union { base, inputs } = self {
1864 terms.push(*base);
1865 terms.extend(inputs);
1866 } else {
1867 terms.push(self);
1868 }
1869 if let HirRelationExpr::Union { base, inputs } = other {
1870 terms.push(*base);
1871 terms.extend(inputs);
1872 } else {
1873 terms.push(other);
1874 }
1875 HirRelationExpr::Union {
1876 base: Box::new(terms.remove(0)),
1877 inputs: terms,
1878 }
1879 }
1880
1881 pub fn exists(self) -> HirScalarExpr {
1882 HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1883 }
1884
1885 pub fn select(self) -> HirScalarExpr {
1886 HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1887 }
1888
1889 pub fn join(
1890 self,
1891 mut right: HirRelationExpr,
1892 on: HirScalarExpr,
1893 kind: JoinKind,
1894 ) -> HirRelationExpr {
1895 if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1896 {
1897 #[allow(deprecated)]
1901 right.visit_columns_mut(0, &mut |depth, col| {
1902 if col.level > depth {
1903 col.level -= 1;
1904 }
1905 });
1906 right
1907 } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1908 self
1909 } else {
1910 HirRelationExpr::Join {
1911 left: Box::new(self),
1912 right: Box::new(right),
1913 on,
1914 kind,
1915 }
1916 }
1917 }
1918
1919 pub fn take(&mut self) -> HirRelationExpr {
1920 mem::replace(
1921 self,
1922 HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1923 )
1924 }
1925
1926 #[deprecated = "Use `Visit::visit_post`."]
1927 pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1928 where
1929 F: FnMut(&'a Self, usize),
1930 {
1931 #[allow(deprecated)]
1932 let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1933 depth: usize|
1934 -> Result<(), ()> {
1935 f(e, depth);
1936 Ok(())
1937 });
1938 }
1939
1940 #[deprecated = "Use `Visit::try_visit_post`."]
1941 pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1942 where
1943 F: FnMut(&'a Self, usize) -> Result<(), E>,
1944 {
1945 #[allow(deprecated)]
1946 self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1947 e.visit_fallible(depth, f)
1948 })?;
1949 f(self, depth)
1950 }
1951
1952 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1953 pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1954 where
1955 F: FnMut(&'a Self, usize) -> Result<(), E>,
1956 {
1957 match self {
1958 HirRelationExpr::Constant { .. }
1959 | HirRelationExpr::Get { .. }
1960 | HirRelationExpr::CallTable { .. } => (),
1961 HirRelationExpr::Let { body, value, .. } => {
1962 f(value, depth)?;
1963 f(body, depth)?;
1964 }
1965 HirRelationExpr::LetRec {
1966 limit: _,
1967 bindings,
1968 body,
1969 } => {
1970 for (_, _, value, _) in bindings.iter() {
1971 f(value, depth)?;
1972 }
1973 f(body, depth)?;
1974 }
1975 HirRelationExpr::Project { input, .. } => {
1976 f(input, depth)?;
1977 }
1978 HirRelationExpr::Map { input, .. } => {
1979 f(input, depth)?;
1980 }
1981 HirRelationExpr::Filter { input, .. } => {
1982 f(input, depth)?;
1983 }
1984 HirRelationExpr::Join { left, right, .. } => {
1985 f(left, depth)?;
1986 f(right, depth + 1)?;
1987 }
1988 HirRelationExpr::Reduce { input, .. } => {
1989 f(input, depth)?;
1990 }
1991 HirRelationExpr::Distinct { input } => {
1992 f(input, depth)?;
1993 }
1994 HirRelationExpr::TopK { input, .. } => {
1995 f(input, depth)?;
1996 }
1997 HirRelationExpr::Negate { input } => {
1998 f(input, depth)?;
1999 }
2000 HirRelationExpr::Threshold { input } => {
2001 f(input, depth)?;
2002 }
2003 HirRelationExpr::Union { base, inputs } => {
2004 f(base, depth)?;
2005 for input in inputs {
2006 f(input, depth)?;
2007 }
2008 }
2009 }
2010 Ok(())
2011 }
2012
2013 #[deprecated = "Use `Visit::visit_mut_post` instead."]
2014 pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
2015 where
2016 F: FnMut(&mut Self, usize),
2017 {
2018 #[allow(deprecated)]
2019 let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2020 depth: usize|
2021 -> Result<(), ()> {
2022 f(e, depth);
2023 Ok(())
2024 });
2025 }
2026
2027 #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
2028 pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2029 where
2030 F: FnMut(&mut Self, usize) -> Result<(), E>,
2031 {
2032 #[allow(deprecated)]
2033 self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
2034 e.visit_mut_fallible(depth, f)
2035 })?;
2036 f(self, depth)
2037 }
2038
2039 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
2040 pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
2041 where
2042 F: FnMut(&'a mut Self, usize) -> Result<(), E>,
2043 {
2044 match self {
2045 HirRelationExpr::Constant { .. }
2046 | HirRelationExpr::Get { .. }
2047 | HirRelationExpr::CallTable { .. } => (),
2048 HirRelationExpr::Let { body, value, .. } => {
2049 f(value, depth)?;
2050 f(body, depth)?;
2051 }
2052 HirRelationExpr::LetRec {
2053 limit: _,
2054 bindings,
2055 body,
2056 } => {
2057 for (_, _, value, _) in bindings.iter_mut() {
2058 f(value, depth)?;
2059 }
2060 f(body, depth)?;
2061 }
2062 HirRelationExpr::Project { input, .. } => {
2063 f(input, depth)?;
2064 }
2065 HirRelationExpr::Map { input, .. } => {
2066 f(input, depth)?;
2067 }
2068 HirRelationExpr::Filter { input, .. } => {
2069 f(input, depth)?;
2070 }
2071 HirRelationExpr::Join { left, right, .. } => {
2072 f(left, depth)?;
2073 f(right, depth + 1)?;
2074 }
2075 HirRelationExpr::Reduce { input, .. } => {
2076 f(input, depth)?;
2077 }
2078 HirRelationExpr::Distinct { input } => {
2079 f(input, depth)?;
2080 }
2081 HirRelationExpr::TopK { input, .. } => {
2082 f(input, depth)?;
2083 }
2084 HirRelationExpr::Negate { input } => {
2085 f(input, depth)?;
2086 }
2087 HirRelationExpr::Threshold { input } => {
2088 f(input, depth)?;
2089 }
2090 HirRelationExpr::Union { base, inputs } => {
2091 f(base, depth)?;
2092 for input in inputs {
2093 f(input, depth)?;
2094 }
2095 }
2096 }
2097 Ok(())
2098 }
2099
2100 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2101 pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
2107 where
2108 F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
2109 {
2110 #[allow(deprecated)]
2111 self.visit_fallible(depth, &mut |e: &HirRelationExpr,
2112 depth: usize|
2113 -> Result<(), E> {
2114 match e {
2115 HirRelationExpr::Join { on, .. } => {
2116 f(on, depth)?;
2117 }
2118 HirRelationExpr::Map { scalars, .. } => {
2119 for scalar in scalars {
2120 f(scalar, depth)?;
2121 }
2122 }
2123 HirRelationExpr::CallTable { exprs, .. } => {
2124 for expr in exprs {
2125 f(expr, depth)?;
2126 }
2127 }
2128 HirRelationExpr::Filter { predicates, .. } => {
2129 for predicate in predicates {
2130 f(predicate, depth)?;
2131 }
2132 }
2133 HirRelationExpr::Reduce { aggregates, .. } => {
2134 for aggregate in aggregates {
2135 f(&aggregate.expr, depth)?;
2136 }
2137 }
2138 HirRelationExpr::TopK { limit, offset, .. } => {
2139 if let Some(limit) = limit {
2140 f(limit, depth)?;
2141 }
2142 f(offset, depth)?;
2143 }
2144 HirRelationExpr::Union { .. }
2145 | HirRelationExpr::Let { .. }
2146 | HirRelationExpr::LetRec { .. }
2147 | HirRelationExpr::Project { .. }
2148 | HirRelationExpr::Distinct { .. }
2149 | HirRelationExpr::Negate { .. }
2150 | HirRelationExpr::Threshold { .. }
2151 | HirRelationExpr::Constant { .. }
2152 | HirRelationExpr::Get { .. } => (),
2153 }
2154 Ok(())
2155 })
2156 }
2157
2158 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2159 pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2161 where
2162 F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2163 {
2164 #[allow(deprecated)]
2165 self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2166 depth: usize|
2167 -> Result<(), E> {
2168 match e {
2169 HirRelationExpr::Join { on, .. } => {
2170 f(on, depth)?;
2171 }
2172 HirRelationExpr::Map { scalars, .. } => {
2173 for scalar in scalars.iter_mut() {
2174 f(scalar, depth)?;
2175 }
2176 }
2177 HirRelationExpr::CallTable { exprs, .. } => {
2178 for expr in exprs.iter_mut() {
2179 f(expr, depth)?;
2180 }
2181 }
2182 HirRelationExpr::Filter { predicates, .. } => {
2183 for predicate in predicates.iter_mut() {
2184 f(predicate, depth)?;
2185 }
2186 }
2187 HirRelationExpr::Reduce { aggregates, .. } => {
2188 for aggregate in aggregates.iter_mut() {
2189 f(&mut aggregate.expr, depth)?;
2190 }
2191 }
2192 HirRelationExpr::TopK { limit, offset, .. } => {
2193 if let Some(limit) = limit {
2194 f(limit, depth)?;
2195 }
2196 f(offset, depth)?;
2197 }
2198 HirRelationExpr::Union { .. }
2199 | HirRelationExpr::Let { .. }
2200 | HirRelationExpr::LetRec { .. }
2201 | HirRelationExpr::Project { .. }
2202 | HirRelationExpr::Distinct { .. }
2203 | HirRelationExpr::Negate { .. }
2204 | HirRelationExpr::Threshold { .. }
2205 | HirRelationExpr::Constant { .. }
2206 | HirRelationExpr::Get { .. } => (),
2207 }
2208 Ok(())
2209 })
2210 }
2211
2212 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2213 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2219 where
2220 F: FnMut(usize, &ColumnRef),
2221 {
2222 #[allow(deprecated)]
2223 let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2224 depth: usize|
2225 -> Result<(), ()> {
2226 e.visit_columns(depth, f);
2227 Ok(())
2228 });
2229 }
2230
2231 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2232 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2234 where
2235 F: FnMut(usize, &mut ColumnRef),
2236 {
2237 #[allow(deprecated)]
2238 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2239 depth: usize|
2240 -> Result<(), ()> {
2241 e.visit_columns_mut(depth, f);
2242 Ok(())
2243 });
2244 }
2245
2246 pub fn bind_parameters(
2249 &mut self,
2250 scx: &StatementContext,
2251 lifetime: QueryLifetime,
2252 params: &Params,
2253 ) -> Result<(), PlanError> {
2254 #[allow(deprecated)]
2255 self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2256 e.bind_parameters(scx, lifetime, params)
2257 })
2258 }
2259
2260 pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2261 let mut contains_parameters = false;
2262 #[allow(deprecated)]
2263 self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2264 if e.contains_parameters() {
2265 contains_parameters = true;
2266 }
2267 Ok::<(), PlanError>(())
2268 })?;
2269 Ok(contains_parameters)
2270 }
2271
2272 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2274 #[allow(deprecated)]
2275 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2276 depth: usize|
2277 -> Result<(), ()> {
2278 e.splice_parameters(params, depth);
2279 Ok(())
2280 });
2281 }
2282
2283 pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2285 let rows = rows
2286 .into_iter()
2287 .map(move |datums| Row::pack_slice(&datums))
2288 .collect();
2289 HirRelationExpr::Constant { rows, typ }
2290 }
2291
2292 pub fn finish_maintained(
2298 &mut self,
2299 finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2300 group_size_hints: GroupSizeHints,
2301 ) {
2302 if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2303 let old_finishing = mem::replace(
2304 finishing,
2305 HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2306 );
2307 *self = HirRelationExpr::top_k(
2308 std::mem::replace(
2309 self,
2310 HirRelationExpr::Constant {
2311 rows: vec![],
2312 typ: SqlRelationType::new(Vec::new()),
2313 },
2314 ),
2315 vec![],
2316 old_finishing.order_by,
2317 old_finishing.limit,
2318 old_finishing.offset,
2319 group_size_hints.limit_input_group_size,
2320 )
2321 .project(old_finishing.project);
2322 }
2323 }
2324
2325 pub fn trivial_row_set_finishing_hir(
2330 arity: usize,
2331 ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2332 RowSetFinishing {
2333 order_by: Vec::new(),
2334 limit: None,
2335 offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2336 project: (0..arity).collect(),
2337 }
2338 }
2339
2340 pub fn is_trivial_row_set_finishing_hir(
2345 rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2346 arity: usize,
2347 ) -> bool {
2348 rsf.limit.is_none()
2349 && rsf.order_by.is_empty()
2350 && rsf
2351 .offset
2352 .clone()
2353 .try_into_literal_int64()
2354 .is_ok_and(|o| o == 0)
2355 && rsf.project.iter().copied().eq(0..arity)
2356 }
2357
2358 pub fn could_run_expensive_function(&self) -> bool {
2367 let mut result = false;
2368 if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2369 use HirRelationExpr::*;
2370 use HirScalarExpr::*;
2371
2372 self.visit_children(|scalar: &HirScalarExpr| {
2373 if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2374 result |= match scalar {
2375 Column(..)
2376 | Literal(..)
2377 | CallUnmaterializable(..)
2378 | If { .. }
2379 | Parameter(..)
2380 | Select(..)
2381 | Exists(..) => false,
2382 CallUnary { .. }
2384 | CallBinary { .. }
2385 | CallVariadic { .. }
2386 | Windowing(..) => true,
2387 };
2388 }) {
2389 result = true;
2391 }
2392 });
2393
2394 result |= matches!(e, CallTable { .. } | Reduce { .. });
2397 }) {
2398 result = true;
2400 }
2401
2402 result
2403 }
2404
2405 pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2407 let mut contains = false;
2408 self.visit_post(&mut |expr| {
2409 expr.visit_children(|expr: &HirScalarExpr| {
2410 contains = contains || expr.contains_temporal()
2411 })
2412 })?;
2413 Ok(contains)
2414 }
2415
2416 pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2418 let mut contains = false;
2419 self.visit_post(&mut |expr| {
2420 expr.visit_children(|expr: &HirScalarExpr| {
2421 contains = contains || expr.contains_unmaterializable()
2422 })
2423 })?;
2424 Ok(contains)
2425 }
2426}
2427
2428impl CollectionPlan for HirRelationExpr {
2429 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2436 if let Self::Get {
2437 id: Id::Global(id), ..
2438 } = self
2439 {
2440 out.insert(*id);
2441 }
2442 self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2443 }
2444}
2445
2446impl VisitChildren<Self> for HirRelationExpr {
2447 fn visit_children<F>(&self, mut f: F)
2448 where
2449 F: FnMut(&Self),
2450 {
2451 VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2455 #[allow(deprecated)]
2456 Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2457 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2458 f(expr.as_ref())
2459 }
2460 _ => (),
2461 });
2462 });
2463
2464 use HirRelationExpr::*;
2465 match self {
2466 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2467 Let {
2468 name: _,
2469 id: _,
2470 value,
2471 body,
2472 } => {
2473 f(value);
2474 f(body);
2475 }
2476 LetRec {
2477 limit: _,
2478 bindings,
2479 body,
2480 } => {
2481 for (_, _, value, _) in bindings.iter() {
2482 f(value);
2483 }
2484 f(body);
2485 }
2486 Project { input, outputs: _ } => f(input),
2487 Map { input, scalars: _ } => {
2488 f(input);
2489 }
2490 CallTable { func: _, exprs: _ } => (),
2491 Filter {
2492 input,
2493 predicates: _,
2494 } => {
2495 f(input);
2496 }
2497 Join {
2498 left,
2499 right,
2500 on: _,
2501 kind: _,
2502 } => {
2503 f(left);
2504 f(right);
2505 }
2506 Reduce {
2507 input,
2508 group_key: _,
2509 aggregates: _,
2510 expected_group_size: _,
2511 } => {
2512 f(input);
2513 }
2514 Distinct { input }
2515 | TopK {
2516 input,
2517 group_key: _,
2518 order_key: _,
2519 limit: _,
2520 offset: _,
2521 expected_group_size: _,
2522 }
2523 | Negate { input }
2524 | Threshold { input } => {
2525 f(input);
2526 }
2527 Union { base, inputs } => {
2528 f(base);
2529 for input in inputs {
2530 f(input);
2531 }
2532 }
2533 }
2534 }
2535
2536 fn visit_mut_children<F>(&mut self, mut f: F)
2537 where
2538 F: FnMut(&mut Self),
2539 {
2540 VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2544 #[allow(deprecated)]
2545 Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2546 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2547 f(expr.as_mut())
2548 }
2549 _ => (),
2550 });
2551 });
2552
2553 use HirRelationExpr::*;
2554 match self {
2555 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2556 Let {
2557 name: _,
2558 id: _,
2559 value,
2560 body,
2561 } => {
2562 f(value);
2563 f(body);
2564 }
2565 LetRec {
2566 limit: _,
2567 bindings,
2568 body,
2569 } => {
2570 for (_, _, value, _) in bindings.iter_mut() {
2571 f(value);
2572 }
2573 f(body);
2574 }
2575 Project { input, outputs: _ } => f(input),
2576 Map { input, scalars: _ } => {
2577 f(input);
2578 }
2579 CallTable { func: _, exprs: _ } => (),
2580 Filter {
2581 input,
2582 predicates: _,
2583 } => {
2584 f(input);
2585 }
2586 Join {
2587 left,
2588 right,
2589 on: _,
2590 kind: _,
2591 } => {
2592 f(left);
2593 f(right);
2594 }
2595 Reduce {
2596 input,
2597 group_key: _,
2598 aggregates: _,
2599 expected_group_size: _,
2600 } => {
2601 f(input);
2602 }
2603 Distinct { input }
2604 | TopK {
2605 input,
2606 group_key: _,
2607 order_key: _,
2608 limit: _,
2609 offset: _,
2610 expected_group_size: _,
2611 }
2612 | Negate { input }
2613 | Threshold { input } => {
2614 f(input);
2615 }
2616 Union { base, inputs } => {
2617 f(base);
2618 for input in inputs {
2619 f(input);
2620 }
2621 }
2622 }
2623 }
2624
2625 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2626 where
2627 F: FnMut(&Self) -> Result<(), E>,
2628 E: From<RecursionLimitError>,
2629 {
2630 VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2634 Visit::try_visit_post(expr, &mut |expr| match expr {
2635 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2636 f(expr.as_ref())
2637 }
2638 _ => Ok(()),
2639 })
2640 })?;
2641
2642 use HirRelationExpr::*;
2643 match self {
2644 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2645 Let {
2646 name: _,
2647 id: _,
2648 value,
2649 body,
2650 } => {
2651 f(value)?;
2652 f(body)?;
2653 }
2654 LetRec {
2655 limit: _,
2656 bindings,
2657 body,
2658 } => {
2659 for (_, _, value, _) in bindings.iter() {
2660 f(value)?;
2661 }
2662 f(body)?;
2663 }
2664 Project { input, outputs: _ } => f(input)?,
2665 Map { input, scalars: _ } => {
2666 f(input)?;
2667 }
2668 CallTable { func: _, exprs: _ } => (),
2669 Filter {
2670 input,
2671 predicates: _,
2672 } => {
2673 f(input)?;
2674 }
2675 Join {
2676 left,
2677 right,
2678 on: _,
2679 kind: _,
2680 } => {
2681 f(left)?;
2682 f(right)?;
2683 }
2684 Reduce {
2685 input,
2686 group_key: _,
2687 aggregates: _,
2688 expected_group_size: _,
2689 } => {
2690 f(input)?;
2691 }
2692 Distinct { input }
2693 | TopK {
2694 input,
2695 group_key: _,
2696 order_key: _,
2697 limit: _,
2698 offset: _,
2699 expected_group_size: _,
2700 }
2701 | Negate { input }
2702 | Threshold { input } => {
2703 f(input)?;
2704 }
2705 Union { base, inputs } => {
2706 f(base)?;
2707 for input in inputs {
2708 f(input)?;
2709 }
2710 }
2711 }
2712 Ok(())
2713 }
2714
2715 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2716 where
2717 F: FnMut(&mut Self) -> Result<(), E>,
2718 E: From<RecursionLimitError>,
2719 {
2720 VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2724 Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2725 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2726 f(expr.as_mut())
2727 }
2728 _ => Ok(()),
2729 })
2730 })?;
2731
2732 use HirRelationExpr::*;
2733 match self {
2734 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2735 Let {
2736 name: _,
2737 id: _,
2738 value,
2739 body,
2740 } => {
2741 f(value)?;
2742 f(body)?;
2743 }
2744 LetRec {
2745 limit: _,
2746 bindings,
2747 body,
2748 } => {
2749 for (_, _, value, _) in bindings.iter_mut() {
2750 f(value)?;
2751 }
2752 f(body)?;
2753 }
2754 Project { input, outputs: _ } => f(input)?,
2755 Map { input, scalars: _ } => {
2756 f(input)?;
2757 }
2758 CallTable { func: _, exprs: _ } => (),
2759 Filter {
2760 input,
2761 predicates: _,
2762 } => {
2763 f(input)?;
2764 }
2765 Join {
2766 left,
2767 right,
2768 on: _,
2769 kind: _,
2770 } => {
2771 f(left)?;
2772 f(right)?;
2773 }
2774 Reduce {
2775 input,
2776 group_key: _,
2777 aggregates: _,
2778 expected_group_size: _,
2779 } => {
2780 f(input)?;
2781 }
2782 Distinct { input }
2783 | TopK {
2784 input,
2785 group_key: _,
2786 order_key: _,
2787 limit: _,
2788 offset: _,
2789 expected_group_size: _,
2790 }
2791 | Negate { input }
2792 | Threshold { input } => {
2793 f(input)?;
2794 }
2795 Union { base, inputs } => {
2796 f(base)?;
2797 for input in inputs {
2798 f(input)?;
2799 }
2800 }
2801 }
2802 Ok(())
2803 }
2804}
2805
2806impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2807 fn visit_children<F>(&self, mut f: F)
2808 where
2809 F: FnMut(&HirScalarExpr),
2810 {
2811 use HirRelationExpr::*;
2812 match self {
2813 Constant { rows: _, typ: _ }
2814 | Get { id: _, typ: _ }
2815 | Let {
2816 name: _,
2817 id: _,
2818 value: _,
2819 body: _,
2820 }
2821 | LetRec {
2822 limit: _,
2823 bindings: _,
2824 body: _,
2825 }
2826 | Project {
2827 input: _,
2828 outputs: _,
2829 } => (),
2830 Map { input: _, scalars } => {
2831 for scalar in scalars {
2832 f(scalar);
2833 }
2834 }
2835 CallTable { func: _, exprs } => {
2836 for expr in exprs {
2837 f(expr);
2838 }
2839 }
2840 Filter {
2841 input: _,
2842 predicates,
2843 } => {
2844 for predicate in predicates {
2845 f(predicate);
2846 }
2847 }
2848 Join {
2849 left: _,
2850 right: _,
2851 on,
2852 kind: _,
2853 } => f(on),
2854 Reduce {
2855 input: _,
2856 group_key: _,
2857 aggregates,
2858 expected_group_size: _,
2859 } => {
2860 for aggregate in aggregates {
2861 f(aggregate.expr.as_ref());
2862 }
2863 }
2864 TopK {
2865 input: _,
2866 group_key: _,
2867 order_key: _,
2868 limit,
2869 offset,
2870 expected_group_size: _,
2871 } => {
2872 if let Some(limit) = limit {
2873 f(limit)
2874 }
2875 f(offset)
2876 }
2877 Distinct { input: _ }
2878 | Negate { input: _ }
2879 | Threshold { input: _ }
2880 | Union { base: _, inputs: _ } => (),
2881 }
2882 }
2883
2884 fn visit_mut_children<F>(&mut self, mut f: F)
2885 where
2886 F: FnMut(&mut HirScalarExpr),
2887 {
2888 use HirRelationExpr::*;
2889 match self {
2890 Constant { rows: _, typ: _ }
2891 | Get { id: _, typ: _ }
2892 | Let {
2893 name: _,
2894 id: _,
2895 value: _,
2896 body: _,
2897 }
2898 | LetRec {
2899 limit: _,
2900 bindings: _,
2901 body: _,
2902 }
2903 | Project {
2904 input: _,
2905 outputs: _,
2906 } => (),
2907 Map { input: _, scalars } => {
2908 for scalar in scalars {
2909 f(scalar);
2910 }
2911 }
2912 CallTable { func: _, exprs } => {
2913 for expr in exprs {
2914 f(expr);
2915 }
2916 }
2917 Filter {
2918 input: _,
2919 predicates,
2920 } => {
2921 for predicate in predicates {
2922 f(predicate);
2923 }
2924 }
2925 Join {
2926 left: _,
2927 right: _,
2928 on,
2929 kind: _,
2930 } => f(on),
2931 Reduce {
2932 input: _,
2933 group_key: _,
2934 aggregates,
2935 expected_group_size: _,
2936 } => {
2937 for aggregate in aggregates {
2938 f(aggregate.expr.as_mut());
2939 }
2940 }
2941 TopK {
2942 input: _,
2943 group_key: _,
2944 order_key: _,
2945 limit,
2946 offset,
2947 expected_group_size: _,
2948 } => {
2949 if let Some(limit) = limit {
2950 f(limit)
2951 }
2952 f(offset)
2953 }
2954 Distinct { input: _ }
2955 | Negate { input: _ }
2956 | Threshold { input: _ }
2957 | Union { base: _, inputs: _ } => (),
2958 }
2959 }
2960
2961 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2962 where
2963 F: FnMut(&HirScalarExpr) -> Result<(), E>,
2964 E: From<RecursionLimitError>,
2965 {
2966 use HirRelationExpr::*;
2967 match self {
2968 Constant { rows: _, typ: _ }
2969 | Get { id: _, typ: _ }
2970 | Let {
2971 name: _,
2972 id: _,
2973 value: _,
2974 body: _,
2975 }
2976 | LetRec {
2977 limit: _,
2978 bindings: _,
2979 body: _,
2980 }
2981 | Project {
2982 input: _,
2983 outputs: _,
2984 } => (),
2985 Map { input: _, scalars } => {
2986 for scalar in scalars {
2987 f(scalar)?;
2988 }
2989 }
2990 CallTable { func: _, exprs } => {
2991 for expr in exprs {
2992 f(expr)?;
2993 }
2994 }
2995 Filter {
2996 input: _,
2997 predicates,
2998 } => {
2999 for predicate in predicates {
3000 f(predicate)?;
3001 }
3002 }
3003 Join {
3004 left: _,
3005 right: _,
3006 on,
3007 kind: _,
3008 } => f(on)?,
3009 Reduce {
3010 input: _,
3011 group_key: _,
3012 aggregates,
3013 expected_group_size: _,
3014 } => {
3015 for aggregate in aggregates {
3016 f(aggregate.expr.as_ref())?;
3017 }
3018 }
3019 TopK {
3020 input: _,
3021 group_key: _,
3022 order_key: _,
3023 limit,
3024 offset,
3025 expected_group_size: _,
3026 } => {
3027 if let Some(limit) = limit {
3028 f(limit)?
3029 }
3030 f(offset)?
3031 }
3032 Distinct { input: _ }
3033 | Negate { input: _ }
3034 | Threshold { input: _ }
3035 | Union { base: _, inputs: _ } => (),
3036 }
3037 Ok(())
3038 }
3039
3040 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3041 where
3042 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
3043 E: From<RecursionLimitError>,
3044 {
3045 use HirRelationExpr::*;
3046 match self {
3047 Constant { rows: _, typ: _ }
3048 | Get { id: _, typ: _ }
3049 | Let {
3050 name: _,
3051 id: _,
3052 value: _,
3053 body: _,
3054 }
3055 | LetRec {
3056 limit: _,
3057 bindings: _,
3058 body: _,
3059 }
3060 | Project {
3061 input: _,
3062 outputs: _,
3063 } => (),
3064 Map { input: _, scalars } => {
3065 for scalar in scalars {
3066 f(scalar)?;
3067 }
3068 }
3069 CallTable { func: _, exprs } => {
3070 for expr in exprs {
3071 f(expr)?;
3072 }
3073 }
3074 Filter {
3075 input: _,
3076 predicates,
3077 } => {
3078 for predicate in predicates {
3079 f(predicate)?;
3080 }
3081 }
3082 Join {
3083 left: _,
3084 right: _,
3085 on,
3086 kind: _,
3087 } => f(on)?,
3088 Reduce {
3089 input: _,
3090 group_key: _,
3091 aggregates,
3092 expected_group_size: _,
3093 } => {
3094 for aggregate in aggregates {
3095 f(aggregate.expr.as_mut())?;
3096 }
3097 }
3098 TopK {
3099 input: _,
3100 group_key: _,
3101 order_key: _,
3102 limit,
3103 offset,
3104 expected_group_size: _,
3105 } => {
3106 if let Some(limit) = limit {
3107 f(limit)?
3108 }
3109 f(offset)?
3110 }
3111 Distinct { input: _ }
3112 | Negate { input: _ }
3113 | Threshold { input: _ }
3114 | Union { base: _, inputs: _ } => (),
3115 }
3116 Ok(())
3117 }
3118}
3119
3120impl HirScalarExpr {
3121 pub fn name(&self) -> Option<Arc<str>> {
3122 use HirScalarExpr::*;
3123 match self {
3124 Column(_, name)
3125 | Parameter(_, name)
3126 | Literal(_, _, name)
3127 | CallUnmaterializable(_, name)
3128 | CallUnary { name, .. }
3129 | CallBinary { name, .. }
3130 | CallVariadic { name, .. }
3131 | If { name, .. }
3132 | Exists(_, name)
3133 | Select(_, name)
3134 | Windowing(_, name) => name.0.clone(),
3135 }
3136 }
3137
3138 pub fn bind_parameters(
3141 &mut self,
3142 scx: &StatementContext,
3143 lifetime: QueryLifetime,
3144 params: &Params,
3145 ) -> Result<(), PlanError> {
3146 #[allow(deprecated)]
3147 self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3148 if let HirScalarExpr::Parameter(n, name) = e {
3149 let datum = match params.datums.iter().nth(*n - 1) {
3150 None => return Err(PlanError::UnknownParameter(*n)),
3151 Some(datum) => datum,
3152 };
3153 let scalar_type = ¶ms.execute_types[*n - 1];
3154 let row = Row::pack([datum]);
3155 let column_type = scalar_type.clone().nullable(datum.is_null());
3156
3157 let name = if let Some(name) = &name.0 {
3158 Some(Arc::clone(name))
3159 } else {
3160 Some(Arc::from(format!("${n}")))
3161 };
3162
3163 let qcx = QueryContext::root(scx, lifetime);
3164 let ecx = execute_expr_context(&qcx);
3165
3166 *e = plan_cast(
3167 &ecx,
3168 *EXECUTE_CAST_CONTEXT,
3169 HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3170 ¶ms.expected_types[*n - 1],
3171 )
3172 .expect("checked in plan_params");
3173 }
3174 Ok(())
3175 })
3176 }
3177
3178 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3189 #[allow(deprecated)]
3190 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3191 e: &mut HirScalarExpr|
3192 -> Result<(), ()> {
3193 if let HirScalarExpr::Parameter(i, _name) = e {
3194 *e = params[*i - 1].clone();
3195 e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3198 if col.level >= d {
3199 col.level += depth
3200 }
3201 });
3202 }
3203 Ok(())
3204 });
3205 }
3206
3207 pub fn contains_temporal(&self) -> bool {
3209 let mut contains = false;
3210 #[allow(deprecated)]
3211 self.visit_post_nolimit(&mut |e| {
3212 if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3213 contains = true;
3214 }
3215 });
3216 contains
3217 }
3218
3219 pub fn contains_unmaterializable(&self) -> bool {
3221 let mut contains = false;
3222 #[allow(deprecated)]
3223 self.visit_post_nolimit(&mut |e| {
3224 if let Self::CallUnmaterializable(_, _) = e {
3225 contains = true;
3226 }
3227 });
3228 contains
3229 }
3230
3231 pub fn column(index: usize) -> HirScalarExpr {
3235 HirScalarExpr::Column(
3236 ColumnRef {
3237 level: 0,
3238 column: index,
3239 },
3240 TreatAsEqual(None),
3241 )
3242 }
3243
3244 pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3246 HirScalarExpr::Column(cr, TreatAsEqual(None))
3247 }
3248
3249 pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3252 HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3253 }
3254
3255 pub fn parameter(n: usize) -> HirScalarExpr {
3256 HirScalarExpr::Parameter(n, TreatAsEqual(None))
3257 }
3258
3259 pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3260 let col_type = scalar_type.nullable(datum.is_null());
3261 soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3262 let row = Row::pack([datum]);
3263 HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3264 }
3265
3266 pub fn literal_true() -> HirScalarExpr {
3267 HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3268 }
3269
3270 pub fn literal_false() -> HirScalarExpr {
3271 HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3272 }
3273
3274 pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3275 HirScalarExpr::literal(Datum::Null, scalar_type)
3276 }
3277
3278 pub fn literal_1d_array(
3279 datums: Vec<Datum>,
3280 element_scalar_type: SqlScalarType,
3281 ) -> Result<HirScalarExpr, PlanError> {
3282 let scalar_type = match element_scalar_type {
3283 SqlScalarType::Array(_) => {
3284 sql_bail!("cannot build array from array type");
3285 }
3286 typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3287 };
3288
3289 let mut row = Row::default();
3290 row.packer()
3291 .try_push_array(
3292 &[ArrayDimension {
3293 lower_bound: 1,
3294 length: datums.len(),
3295 }],
3296 datums,
3297 )
3298 .expect("array constructed to be valid");
3299
3300 Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3301 }
3302
3303 pub fn as_literal(&self) -> Option<Datum<'_>> {
3304 if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3305 Some(row.unpack_first())
3306 } else {
3307 None
3308 }
3309 }
3310
3311 pub fn is_literal_true(&self) -> bool {
3312 Some(Datum::True) == self.as_literal()
3313 }
3314
3315 pub fn is_literal_false(&self) -> bool {
3316 Some(Datum::False) == self.as_literal()
3317 }
3318
3319 pub fn is_literal_null(&self) -> bool {
3320 Some(Datum::Null) == self.as_literal()
3321 }
3322
3323 pub fn is_constant(&self) -> bool {
3326 let mut worklist = vec![self];
3327 while let Some(expr) = worklist.pop() {
3328 match expr {
3329 Self::Literal(..) => {
3330 }
3332 Self::CallUnary { expr, .. } => {
3333 worklist.push(expr);
3334 }
3335 Self::CallBinary {
3336 func: _,
3337 expr1,
3338 expr2,
3339 name: _,
3340 } => {
3341 worklist.push(expr1);
3342 worklist.push(expr2);
3343 }
3344 Self::CallVariadic {
3345 func: _,
3346 exprs,
3347 name: _,
3348 } => {
3349 worklist.extend(exprs.iter());
3350 }
3351 Self::If {
3353 cond,
3354 then,
3355 els,
3356 name: _,
3357 } => {
3358 worklist.push(cond);
3359 worklist.push(then);
3360 worklist.push(els);
3361 }
3362 _ => {
3363 return false; }
3365 }
3366 }
3367 true
3368 }
3369
3370 pub fn call_unary(self, func: UnaryFunc) -> Self {
3371 HirScalarExpr::CallUnary {
3372 func,
3373 expr: Box::new(self),
3374 name: NameMetadata::default(),
3375 }
3376 }
3377
3378 pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3379 HirScalarExpr::CallBinary {
3380 func: func.into(),
3381 expr1: Box::new(self),
3382 expr2: Box::new(other),
3383 name: NameMetadata::default(),
3384 }
3385 }
3386
3387 pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3388 HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3389 }
3390
3391 pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3392 HirScalarExpr::CallVariadic {
3393 func,
3394 exprs,
3395 name: NameMetadata::default(),
3396 }
3397 }
3398
3399 pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3400 HirScalarExpr::If {
3401 cond: Box::new(cond),
3402 then: Box::new(then),
3403 els: Box::new(els),
3404 name: NameMetadata::default(),
3405 }
3406 }
3407
3408 pub fn windowing(expr: WindowExpr) -> Self {
3409 HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3410 }
3411
3412 pub fn or(self, other: Self) -> Self {
3413 HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3414 }
3415
3416 pub fn and(self, other: Self) -> Self {
3417 HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3418 }
3419
3420 pub fn not(self) -> Self {
3421 self.call_unary(UnaryFunc::Not(func::Not))
3422 }
3423
3424 pub fn call_is_null(self) -> Self {
3425 self.call_unary(UnaryFunc::IsNull(func::IsNull))
3426 }
3427
3428 pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3430 match args.len() {
3431 0 => HirScalarExpr::literal_true(), 1 => args.swap_remove(0),
3433 _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3434 }
3435 }
3436
3437 pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3439 match args.len() {
3440 0 => HirScalarExpr::literal_false(), 1 => args.swap_remove(0),
3442 _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3443 }
3444 }
3445
3446 pub fn take(&mut self) -> Self {
3447 mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3448 }
3449
3450 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3451 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3457 where
3458 F: FnMut(usize, &ColumnRef),
3459 {
3460 #[allow(deprecated)]
3461 let _ = self.visit_recursively(depth, &mut |depth: usize,
3462 e: &HirScalarExpr|
3463 -> Result<(), ()> {
3464 if let HirScalarExpr::Column(col, _name) = e {
3465 f(depth, col)
3466 }
3467 Ok(())
3468 });
3469 }
3470
3471 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3472 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3474 where
3475 F: FnMut(usize, &mut ColumnRef),
3476 {
3477 #[allow(deprecated)]
3478 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3479 e: &mut HirScalarExpr|
3480 -> Result<(), ()> {
3481 if let HirScalarExpr::Column(col, _name) = e {
3482 f(depth, col)
3483 }
3484 Ok(())
3485 });
3486 }
3487
3488 pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3494 where
3495 F: FnMut(usize),
3496 {
3497 #[allow(deprecated)]
3498 let _ = self.visit_recursively(0, &mut |depth: usize,
3499 e: &HirScalarExpr|
3500 -> Result<(), ()> {
3501 if let HirScalarExpr::Column(col, _name) = e {
3502 if col.level == depth {
3503 f(col.column)
3504 }
3505 }
3506 Ok(())
3507 });
3508 }
3509
3510 pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3512 where
3513 F: FnMut(&mut usize),
3514 {
3515 #[allow(deprecated)]
3516 let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3517 e: &mut HirScalarExpr|
3518 -> Result<(), ()> {
3519 if let HirScalarExpr::Column(col, _name) = e {
3520 if col.level == depth {
3521 f(&mut col.column)
3522 }
3523 }
3524 Ok(())
3525 });
3526 }
3527
3528 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3529 pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3533 where
3534 F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3535 {
3536 match self {
3537 HirScalarExpr::Literal(..)
3538 | HirScalarExpr::Parameter(..)
3539 | HirScalarExpr::CallUnmaterializable(..)
3540 | HirScalarExpr::Column(..) => (),
3541 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3542 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3543 expr1.visit_recursively(depth, f)?;
3544 expr2.visit_recursively(depth, f)?;
3545 }
3546 HirScalarExpr::CallVariadic { exprs, .. } => {
3547 for expr in exprs {
3548 expr.visit_recursively(depth, f)?;
3549 }
3550 }
3551 HirScalarExpr::If {
3552 cond,
3553 then,
3554 els,
3555 name: _,
3556 } => {
3557 cond.visit_recursively(depth, f)?;
3558 then.visit_recursively(depth, f)?;
3559 els.visit_recursively(depth, f)?;
3560 }
3561 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3562 #[allow(deprecated)]
3563 expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3564 e.visit_recursively(depth, f)
3565 })?;
3566 }
3567 HirScalarExpr::Windowing(expr, _name) => {
3568 expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3569 }
3570 }
3571 f(depth, self)
3572 }
3573
3574 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3575 pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3577 where
3578 F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3579 {
3580 match self {
3581 HirScalarExpr::Literal(..)
3582 | HirScalarExpr::Parameter(..)
3583 | HirScalarExpr::CallUnmaterializable(..)
3584 | HirScalarExpr::Column(..) => (),
3585 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3586 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3587 expr1.visit_recursively_mut(depth, f)?;
3588 expr2.visit_recursively_mut(depth, f)?;
3589 }
3590 HirScalarExpr::CallVariadic { exprs, .. } => {
3591 for expr in exprs {
3592 expr.visit_recursively_mut(depth, f)?;
3593 }
3594 }
3595 HirScalarExpr::If {
3596 cond,
3597 then,
3598 els,
3599 name: _,
3600 } => {
3601 cond.visit_recursively_mut(depth, f)?;
3602 then.visit_recursively_mut(depth, f)?;
3603 els.visit_recursively_mut(depth, f)?;
3604 }
3605 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3606 #[allow(deprecated)]
3607 expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3608 e.visit_recursively_mut(depth, f)
3609 })?;
3610 }
3611 HirScalarExpr::Windowing(expr, _name) => {
3612 expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3613 }
3614 }
3615 f(depth, self)
3616 }
3617
3618 fn simplify_to_literal(self) -> Option<Row> {
3627 let mut expr = self
3628 .lower_uncorrelated(crate::plan::lowering::Config::default())
3629 .ok()?;
3630 expr.reduce(&[]);
3631 match expr {
3632 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3633 _ => None,
3634 }
3635 }
3636
3637 fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3650 let mut expr = self
3651 .lower_uncorrelated(crate::plan::lowering::Config::default())
3652 .map_err(|err| {
3653 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3654 })?;
3655 expr.reduce(&[]);
3656 match expr {
3657 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3658 mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3659 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3660 ),
3661 _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3662 "Not a constant".to_string(),
3663 )),
3664 }
3665 }
3666
3667 pub fn into_literal_int64(self) -> Option<i64> {
3676 self.simplify_to_literal().and_then(|row| {
3677 let datum = row.unpack_first();
3678 if datum.is_null() {
3679 None
3680 } else {
3681 Some(datum.unwrap_int64())
3682 }
3683 })
3684 }
3685
3686 pub fn into_literal_string(self) -> Option<String> {
3695 self.simplify_to_literal().and_then(|row| {
3696 let datum = row.unpack_first();
3697 if datum.is_null() {
3698 None
3699 } else {
3700 Some(datum.unwrap_str().to_owned())
3701 }
3702 })
3703 }
3704
3705 pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3718 self.simplify_to_literal().and_then(|row| {
3719 let datum = row.unpack_first();
3720 if datum.is_null() {
3721 None
3722 } else {
3723 Some(datum.unwrap_mz_timestamp())
3724 }
3725 })
3726 }
3727
3728 pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3740 if !self.is_constant() {
3746 return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3747 "Expected a constant expression, got {}",
3748 self
3749 )));
3750 }
3751 self.clone()
3752 .simplify_to_literal_with_result()
3753 .and_then(|row| {
3754 let datum = row.unpack_first();
3755 if datum.is_null() {
3756 Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3757 "Expected an expression that evaluates to a non-null value, got {}",
3758 self
3759 )))
3760 } else {
3761 Ok(datum.unwrap_int64())
3762 }
3763 })
3764 }
3765
3766 pub fn contains_parameters(&self) -> bool {
3767 let mut contains_parameters = false;
3768 #[allow(deprecated)]
3769 let _ = self.visit_recursively(0, &mut |_depth: usize,
3770 expr: &HirScalarExpr|
3771 -> Result<(), ()> {
3772 if let HirScalarExpr::Parameter(..) = expr {
3773 contains_parameters = true;
3774 }
3775 Ok(())
3776 });
3777 contains_parameters
3778 }
3779}
3780
3781impl VisitChildren<Self> for HirScalarExpr {
3782 fn visit_children<F>(&self, mut f: F)
3783 where
3784 F: FnMut(&Self),
3785 {
3786 use HirScalarExpr::*;
3787 match self {
3788 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3789 CallUnary { expr, .. } => f(expr),
3790 CallBinary { expr1, expr2, .. } => {
3791 f(expr1);
3792 f(expr2);
3793 }
3794 CallVariadic { exprs, .. } => {
3795 for expr in exprs {
3796 f(expr);
3797 }
3798 }
3799 If {
3800 cond,
3801 then,
3802 els,
3803 name: _,
3804 } => {
3805 f(cond);
3806 f(then);
3807 f(els);
3808 }
3809 Exists(..) | Select(..) => (),
3810 Windowing(expr, _name) => expr.visit_children(f),
3811 }
3812 }
3813
3814 fn visit_mut_children<F>(&mut self, mut f: F)
3815 where
3816 F: FnMut(&mut Self),
3817 {
3818 use HirScalarExpr::*;
3819 match self {
3820 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3821 CallUnary { expr, .. } => f(expr),
3822 CallBinary { expr1, expr2, .. } => {
3823 f(expr1);
3824 f(expr2);
3825 }
3826 CallVariadic { exprs, .. } => {
3827 for expr in exprs {
3828 f(expr);
3829 }
3830 }
3831 If {
3832 cond,
3833 then,
3834 els,
3835 name: _,
3836 } => {
3837 f(cond);
3838 f(then);
3839 f(els);
3840 }
3841 Exists(..) | Select(..) => (),
3842 Windowing(expr, _name) => expr.visit_mut_children(f),
3843 }
3844 }
3845
3846 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3847 where
3848 F: FnMut(&Self) -> Result<(), E>,
3849 E: From<RecursionLimitError>,
3850 {
3851 use HirScalarExpr::*;
3852 match self {
3853 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3854 CallUnary { expr, .. } => f(expr)?,
3855 CallBinary { expr1, expr2, .. } => {
3856 f(expr1)?;
3857 f(expr2)?;
3858 }
3859 CallVariadic { exprs, .. } => {
3860 for expr in exprs {
3861 f(expr)?;
3862 }
3863 }
3864 If {
3865 cond,
3866 then,
3867 els,
3868 name: _,
3869 } => {
3870 f(cond)?;
3871 f(then)?;
3872 f(els)?;
3873 }
3874 Exists(..) | Select(..) => (),
3875 Windowing(expr, _name) => expr.try_visit_children(f)?,
3876 }
3877 Ok(())
3878 }
3879
3880 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3881 where
3882 F: FnMut(&mut Self) -> Result<(), E>,
3883 E: From<RecursionLimitError>,
3884 {
3885 use HirScalarExpr::*;
3886 match self {
3887 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3888 CallUnary { expr, .. } => f(expr)?,
3889 CallBinary { expr1, expr2, .. } => {
3890 f(expr1)?;
3891 f(expr2)?;
3892 }
3893 CallVariadic { exprs, .. } => {
3894 for expr in exprs {
3895 f(expr)?;
3896 }
3897 }
3898 If {
3899 cond,
3900 then,
3901 els,
3902 name: _,
3903 } => {
3904 f(cond)?;
3905 f(then)?;
3906 f(els)?;
3907 }
3908 Exists(..) | Select(..) => (),
3909 Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3910 }
3911 Ok(())
3912 }
3913}
3914
3915impl AbstractExpr for HirScalarExpr {
3916 type Type = SqlColumnType;
3917
3918 fn typ(
3919 &self,
3920 outers: &[SqlRelationType],
3921 inner: &SqlRelationType,
3922 params: &BTreeMap<usize, SqlScalarType>,
3923 ) -> Self::Type {
3924 stack::maybe_grow(|| match self {
3925 HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3926 if *level == 0 {
3927 inner.column_types[*column].clone()
3928 } else {
3929 outers[*level - 1].column_types[*column].clone()
3930 }
3931 }
3932 HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3933 HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3934 HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3935 HirScalarExpr::CallUnary {
3936 expr,
3937 func,
3938 name: _,
3939 } => func.output_type(expr.typ(outers, inner, params)),
3940 HirScalarExpr::CallBinary {
3941 expr1,
3942 expr2,
3943 func,
3944 name: _,
3945 } => func.output_type(&[
3946 expr1.typ(outers, inner, params),
3947 expr2.typ(outers, inner, params),
3948 ]),
3949 HirScalarExpr::CallVariadic {
3950 exprs,
3951 func,
3952 name: _,
3953 } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3954 HirScalarExpr::If {
3955 cond: _,
3956 then,
3957 els,
3958 name: _,
3959 } => {
3960 let then_type = then.typ(outers, inner, params);
3961 let else_type = els.typ(outers, inner, params);
3962 then_type.sql_union(&else_type).unwrap() }
3964 HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3965 HirScalarExpr::Select(expr, _name) => {
3966 let mut outers = outers.to_vec();
3967 outers.insert(0, inner.clone());
3968 expr.typ(&outers, params)
3969 .column_types
3970 .into_element()
3971 .nullable(true)
3972 }
3973 HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3974 })
3975 }
3976}
3977
3978impl AggregateExpr {
3979 pub fn typ(
3980 &self,
3981 outers: &[SqlRelationType],
3982 inner: &SqlRelationType,
3983 params: &BTreeMap<usize, SqlScalarType>,
3984 ) -> SqlColumnType {
3985 self.func.output_type(self.expr.typ(outers, inner, params))
3986 }
3987
3988 pub fn is_count_asterisk(&self) -> bool {
3996 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3997 }
3998}