1use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26 BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
/// Marker type naming the HIR flavour of the planner's intermediate
/// representation.
///
/// It is a unit struct with no data; in this file it only serves as the
/// implementor of the `IR` and `AlgExcept` traits.
#[allow(missing_debug_implementations)]
pub struct Hir;
48
/// Binds the generic `IR` associated types to their HIR representations.
impl IR for Hir {
    /// Relational expressions of this IR are `HirRelationExpr`s.
    type Relation = HirRelationExpr;
    /// Scalar expressions of this IR are `HirScalarExpr`s.
    type Scalar = HirScalarExpr;
}
53
54impl AlgExcept for Hir {
55 fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56 if *all {
57 let rhs = rhs.negate();
58 HirRelationExpr::union(lhs, rhs).threshold()
59 } else {
60 let lhs = lhs.distinct();
61 let rhs = rhs.distinct().negate();
62 HirRelationExpr::union(lhs, rhs).threshold()
63 }
64 }
65
66 fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67 let mut result = None;
68
69 use HirRelationExpr::*;
70 if let Threshold { input } = expr {
71 if let Union { base: lhs, inputs } = input.as_ref() {
72 if let [rhs] = &inputs[..] {
73 if let Negate { input: rhs } = rhs {
74 match (lhs.as_ref(), rhs.as_ref()) {
75 (Distinct { input: lhs }, Distinct { input: rhs }) => {
76 let all = false;
77 let lhs = lhs.as_ref();
78 let rhs = rhs.as_ref();
79 result = Some(Except { all, lhs, rhs })
80 }
81 (lhs, rhs) => {
82 let all = true;
83 result = Some(Except { all, lhs, rhs })
84 }
85 }
86 }
87 }
88 }
89 }
90
91 result
92 }
93}
94
/// A relational expression in the HIR; this is the `Relation` type of the
/// `Hir` IR.
///
/// Child relations are boxed. Scalar sub-expressions are `HirScalarExpr`s,
/// which can in turn embed relations (`HirScalarExpr::Exists` and
/// `HirScalarExpr::Select`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirRelationExpr {
    /// A set of `rows` carried inline, together with its relation type.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection identified by `id`, annotated with its
    /// relation type.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// A group of bindings scoped over `body`.
    ///
    /// NOTE(review): presumably the bindings may refer to each other
    /// recursively and `limit` bounds the iteration — confirm against
    /// `mz_expr::LetRecLimit`.
    LetRec {
        /// Optional limit attached to the whole group of bindings.
        limit: Option<LetRecLimit>,
        /// One `(name, local id, value, relation type)` tuple per binding.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// The expression the bindings are scoped over.
        body: Box<HirRelationExpr>,
    },
    /// A single binding of `value` to the local `id`, scoped over `body`.
    Let {
        /// Name associated with the binding.
        name: String,
        /// Local identifier of the binding.
        id: mz_expr::LocalId,
        /// The bound relation.
        value: Box<HirRelationExpr>,
        /// The expression the binding is scoped over.
        body: Box<HirRelationExpr>,
    },
    /// A projection of `input`; `outputs` presumably lists the indexes of the
    /// columns to keep, in output order — confirm.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// `input` together with the scalar expressions `scalars`; presumably
    /// each scalar contributes one appended column — confirm.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call of the table function `func` on the arguments `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// `input` restricted by `predicates`; presumably a row must satisfy all
    /// of them — confirm.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// A join of `left` and `right` of the given `kind`, with the join
    /// condition `on`.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// A grouped aggregation of `input`.
    ///
    /// `group_key` holds column indexes and `aggregates` the aggregate
    /// expressions to compute; `expected_group_size` is an optional numeric
    /// hint (presumably for the optimizer/renderer — confirm).
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    /// `input` wrapped as "distinct"; used by `AlgExcept::except` for the
    /// non-`all` case.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// A per-group ordered selection over `input`.
    ///
    /// NOTE(review): presumably keeps at most `limit` rows per `group_key`
    /// group after skipping `offset` rows in `order_key` order — confirm.
    TopK {
        /// The source relation.
        input: Box<HirRelationExpr>,
        /// Column indexes forming the grouping key.
        group_key: Vec<usize>,
        /// Ordering applied within each group.
        order_key: Vec<ColumnOrder>,
        /// Optional limit, given as a scalar expression.
        limit: Option<HirScalarExpr>,
        /// Offset, given as a scalar expression.
        offset: HirScalarExpr,
        /// Optional numeric hint about group sizes.
        expected_group_size: Option<u64>,
    },
    /// The negation of `input`; `AlgExcept::except` applies it to the
    /// right-hand operand before taking the union.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// A threshold over `input`; `AlgExcept::except` applies it to the union
    /// of the left operand and the negated right operand.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` with every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
198
/// Optional shared name carried by every `HirScalarExpr` variant.
///
/// NOTE(review): the `TreatAsEqual` wrapper presumably excludes the name from
/// the derived comparison/hash impls — confirm against
/// `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;
201
/// A scalar expression in the HIR; this is the `Scalar` type of the `Hir` IR.
///
/// Every variant carries a `NameMetadata` in addition to its payload.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirScalarExpr {
    /// A reference to a column, addressed by a `ColumnRef` (level + column).
    Column(ColumnRef, NameMetadata),
    /// A parameter identified by number (presumably a prepared-statement
    /// placeholder — confirm).
    Parameter(usize, NameMetadata),
    /// A literal value stored as a `Row`, together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an unmaterializable function; it takes no argument
    /// expressions.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call of the unary function `func` on `expr`.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call of the binary function `func` on `expr1` and `expr2`.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call of the variadic function `func` on `exprs`.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional with condition `cond` and branches `then` / `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A scalar derived from an embedded relation (presumably "is the
    /// relation non-empty?" — confirm).
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// A scalar derived from an embedded relation (presumably the value of a
    /// scalar subquery — confirm).
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call described by a `WindowExpr`.
    Windowing(WindowExpr, NameMetadata),
}
243
/// A window function call: the function itself plus the scalar expressions
/// used for partitioning and ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct WindowExpr {
    /// Which window function is called (scalar, value or aggregate flavour).
    pub func: WindowExprType,
    /// Scalar expressions forming the partition key.
    pub partition_by: Vec<HirScalarExpr>,
    /// Scalar expressions forming the ordering key.
    pub order_by: Vec<HirScalarExpr>,
}
262
263impl WindowExpr {
264 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265 where
266 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267 {
268 #[allow(deprecated)]
269 self.func.visit_expressions(f)?;
270 for expr in self.partition_by.iter() {
271 f(expr)?;
272 }
273 for expr in self.order_by.iter() {
274 f(expr)?;
275 }
276 Ok(())
277 }
278
279 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280 where
281 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282 {
283 #[allow(deprecated)]
284 self.func.visit_expressions_mut(f)?;
285 for expr in self.partition_by.iter_mut() {
286 f(expr)?;
287 }
288 for expr in self.order_by.iter_mut() {
289 f(expr)?;
290 }
291 Ok(())
292 }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296 fn visit_children<F>(&self, mut f: F)
297 where
298 F: FnMut(&HirScalarExpr),
299 {
300 self.func.visit_children(&mut f);
301 for expr in self.partition_by.iter() {
302 f(expr);
303 }
304 for expr in self.order_by.iter() {
305 f(expr);
306 }
307 }
308
309 fn visit_mut_children<F>(&mut self, mut f: F)
310 where
311 F: FnMut(&mut HirScalarExpr),
312 {
313 self.func.visit_mut_children(&mut f);
314 for expr in self.partition_by.iter_mut() {
315 f(expr);
316 }
317 for expr in self.order_by.iter_mut() {
318 f(expr);
319 }
320 }
321
322 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323 where
324 F: FnMut(&HirScalarExpr) -> Result<(), E>,
325 E: From<RecursionLimitError>,
326 {
327 self.func.try_visit_children(&mut f)?;
328 for expr in self.partition_by.iter() {
329 f(expr)?;
330 }
331 for expr in self.order_by.iter() {
332 f(expr)?;
333 }
334 Ok(())
335 }
336
337 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338 where
339 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340 E: From<RecursionLimitError>,
341 {
342 self.func.try_visit_mut_children(&mut f)?;
343 for expr in self.partition_by.iter_mut() {
344 f(expr)?;
345 }
346 for expr in self.order_by.iter_mut() {
347 f(expr)?;
348 }
349 Ok(())
350 }
351}
352
/// The three flavours of window function a `WindowExpr` can call.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum WindowExprType {
    /// A window function without an argument expression (`row_number`,
    /// `rank`, `dense_rank`).
    Scalar(ScalarWindowExpr),
    /// A value window function (`lag`, `lead`, `first_value`, `last_value`,
    /// or a fusion of several of them).
    Value(ValueWindowExpr),
    /// An aggregate function evaluated over a window frame.
    Aggregate(AggregateWindowExpr),
}
375
376impl WindowExprType {
377 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379 where
380 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381 {
382 #[allow(deprecated)]
383 match self {
384 Self::Scalar(expr) => expr.visit_expressions(f),
385 Self::Value(expr) => expr.visit_expressions(f),
386 Self::Aggregate(expr) => expr.visit_expressions(f),
387 }
388 }
389
390 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392 where
393 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394 {
395 #[allow(deprecated)]
396 match self {
397 Self::Scalar(expr) => expr.visit_expressions_mut(f),
398 Self::Value(expr) => expr.visit_expressions_mut(f),
399 Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400 }
401 }
402
403 fn typ(
404 &self,
405 outers: &[SqlRelationType],
406 inner: &SqlRelationType,
407 params: &BTreeMap<usize, SqlScalarType>,
408 ) -> SqlColumnType {
409 match self {
410 Self::Scalar(expr) => expr.typ(outers, inner, params),
411 Self::Value(expr) => expr.typ(outers, inner, params),
412 Self::Aggregate(expr) => expr.typ(outers, inner, params),
413 }
414 }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418 fn visit_children<F>(&self, f: F)
419 where
420 F: FnMut(&HirScalarExpr),
421 {
422 match self {
423 Self::Scalar(_) => (),
424 Self::Value(expr) => expr.visit_children(f),
425 Self::Aggregate(expr) => expr.visit_children(f),
426 }
427 }
428
429 fn visit_mut_children<F>(&mut self, f: F)
430 where
431 F: FnMut(&mut HirScalarExpr),
432 {
433 match self {
434 Self::Scalar(_) => (),
435 Self::Value(expr) => expr.visit_mut_children(f),
436 Self::Aggregate(expr) => expr.visit_mut_children(f),
437 }
438 }
439
440 fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441 where
442 F: FnMut(&HirScalarExpr) -> Result<(), E>,
443 E: From<RecursionLimitError>,
444 {
445 match self {
446 Self::Scalar(_) => Ok(()),
447 Self::Value(expr) => expr.try_visit_children(f),
448 Self::Aggregate(expr) => expr.try_visit_children(f),
449 }
450 }
451
452 fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453 where
454 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455 E: From<RecursionLimitError>,
456 {
457 match self {
458 Self::Scalar(_) => Ok(()),
459 Self::Value(expr) => expr.try_visit_mut_children(f),
460 Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461 }
462 }
463}
464
/// A window function call that takes no argument expression: just the
/// function and the ordering it is evaluated under.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ScalarWindowExpr {
    /// The function being called.
    pub func: ScalarWindowFunc,
    /// The column ordering handed to the lowered `mz_expr::AggregateFunc`.
    pub order_by: Vec<ColumnOrder>,
}
470
471impl ScalarWindowExpr {
472 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473 pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474 where
475 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476 {
477 match self.func {
478 ScalarWindowFunc::RowNumber => {}
479 ScalarWindowFunc::Rank => {}
480 ScalarWindowFunc::DenseRank => {}
481 }
482 Ok(())
483 }
484
485 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486 pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487 where
488 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489 {
490 match self.func {
491 ScalarWindowFunc::RowNumber => {}
492 ScalarWindowFunc::Rank => {}
493 ScalarWindowFunc::DenseRank => {}
494 }
495 Ok(())
496 }
497
498 fn typ(
499 &self,
500 _outers: &[SqlRelationType],
501 _inner: &SqlRelationType,
502 _params: &BTreeMap<usize, SqlScalarType>,
503 ) -> SqlColumnType {
504 self.func.output_type()
505 }
506
507 pub fn into_expr(self) -> mz_expr::AggregateFunc {
508 match self.func {
509 ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510 order_by: self.order_by,
511 },
512 ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513 order_by: self.order_by,
514 },
515 ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516 order_by: self.order_by,
517 },
518 }
519 }
520}
521
/// Window functions that take no argument expression.
///
/// Each one displays as its snake_case name and has a non-nullable `Int64`
/// output type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
529
530impl Display for ScalarWindowFunc {
531 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532 match self {
533 ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534 ScalarWindowFunc::Rank => write!(f, "rank"),
535 ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536 }
537 }
538}
539
540impl ScalarWindowFunc {
541 pub fn output_type(&self) -> SqlColumnType {
542 match self {
543 ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
544 ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
545 ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
546 }
547 }
548}
549
/// A call to a value window function (`lag`, `lead`, `first_value`,
/// `last_value`, or a fusion of several of them).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ValueWindowExpr {
    /// The function being called.
    pub func: ValueWindowFunc,
    /// The single argument expression. For `Lag`/`Lead` and `Fused` its type
    /// is unwrapped as a record when computing the output type.
    pub args: Box<HirScalarExpr>,
    /// The column ordering handed to the lowered `mz_expr::AggregateFunc`.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame; when lowering it is only consumed by `FirstValue`
    /// and `LastValue`.
    pub window_frame: WindowFrame,
    /// Whether nulls are ignored; when lowering it is only consumed by `Lag`
    /// and `Lead`.
    pub ignore_nulls: bool,
}
564
565impl Display for ValueWindowFunc {
566 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567 match self {
568 ValueWindowFunc::Lag => write!(f, "lag"),
569 ValueWindowFunc::Lead => write!(f, "lead"),
570 ValueWindowFunc::FirstValue => write!(f, "first_value"),
571 ValueWindowFunc::LastValue => write!(f, "last_value"),
572 ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573 }
574 }
575}
576
577impl ValueWindowExpr {
578 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580 where
581 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582 {
583 f(&self.args)
584 }
585
586 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588 where
589 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590 {
591 f(&mut self.args)
592 }
593
594 fn typ(
595 &self,
596 outers: &[SqlRelationType],
597 inner: &SqlRelationType,
598 params: &BTreeMap<usize, SqlScalarType>,
599 ) -> SqlColumnType {
600 self.func.output_type(self.args.typ(outers, inner, params))
601 }
602
603 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605 (
606 self.args,
607 self.func
608 .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609 )
610 }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614 fn visit_children<F>(&self, mut f: F)
615 where
616 F: FnMut(&HirScalarExpr),
617 {
618 f(&self.args)
619 }
620
621 fn visit_mut_children<F>(&mut self, mut f: F)
622 where
623 F: FnMut(&mut HirScalarExpr),
624 {
625 f(&mut self.args)
626 }
627
628 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629 where
630 F: FnMut(&HirScalarExpr) -> Result<(), E>,
631 E: From<RecursionLimitError>,
632 {
633 f(&self.args)
634 }
635
636 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637 where
638 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639 E: From<RecursionLimitError>,
640 {
641 f(&mut self.args)
642 }
643}
644
/// The value window functions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ValueWindowFunc {
    /// Displayed as `lag`; lowered to `LagLead` with `LagLeadType::Lag`.
    Lag,
    /// Displayed as `lead`; lowered to `LagLead` with `LagLeadType::Lead`.
    Lead,
    /// Displayed as `first_value`.
    FirstValue,
    /// Displayed as `last_value`.
    LastValue,
    /// Several value window functions evaluated together; the output type is
    /// a record with one field per member function.
    Fused(Vec<ValueWindowFunc>),
}
654
655impl ValueWindowFunc {
656 pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
657 match self {
658 ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659 input_type.scalar_type.unwrap_record_element_type()[0]
661 .clone()
662 .nullable(true)
663 }
664 ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665 input_type.scalar_type.nullable(true)
666 }
667 ValueWindowFunc::Fused(funcs) => {
668 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669 SqlScalarType::Record {
670 fields: funcs
671 .iter()
672 .zip_eq(input_types)
673 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674 .collect(),
675 custom_id: None,
676 }
677 .nullable(false)
678 }
679 }
680 }
681
682 pub fn into_expr(
683 self,
684 order_by: Vec<ColumnOrder>,
685 window_frame: WindowFrame,
686 ignore_nulls: bool,
687 ) -> mz_expr::AggregateFunc {
688 match self {
689 ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691 order_by,
692 lag_lead: mz_expr::LagLeadType::Lag,
693 ignore_nulls,
694 },
695 ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696 order_by,
697 lag_lead: mz_expr::LagLeadType::Lead,
698 ignore_nulls,
699 },
700 ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701 order_by,
702 window_frame,
703 },
704 ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705 order_by,
706 window_frame,
707 },
708 ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709 funcs: funcs
710 .into_iter()
711 .map(|func| {
712 func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713 })
714 .collect(),
715 order_by,
716 },
717 }
718 }
719}
720
/// An aggregate function evaluated as a window function.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateWindowExpr {
    /// The aggregate (function plus argument expression) being evaluated.
    pub aggregate_expr: AggregateExpr,
    /// The column ordering handed to the lowered window aggregate.
    pub order_by: Vec<ColumnOrder>,
    /// The window frame handed to the lowered window aggregate.
    pub window_frame: WindowFrame,
}
727
728impl AggregateWindowExpr {
729 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731 where
732 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733 {
734 f(&self.aggregate_expr.expr)
735 }
736
737 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739 where
740 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741 {
742 f(&mut self.aggregate_expr.expr)
743 }
744
745 fn typ(
746 &self,
747 outers: &[SqlRelationType],
748 inner: &SqlRelationType,
749 params: &BTreeMap<usize, SqlScalarType>,
750 ) -> SqlColumnType {
751 self.aggregate_expr
752 .func
753 .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754 }
755
756 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757 if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758 (
759 self.aggregate_expr.expr,
760 FusedWindowAggregate {
761 wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762 order_by: self.order_by,
763 window_frame: self.window_frame,
764 },
765 )
766 } else {
767 (
768 self.aggregate_expr.expr,
769 WindowAggregate {
770 wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771 order_by: self.order_by,
772 window_frame: self.window_frame,
773 },
774 )
775 }
776 }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780 fn visit_children<F>(&self, mut f: F)
781 where
782 F: FnMut(&HirScalarExpr),
783 {
784 f(&self.aggregate_expr.expr)
785 }
786
787 fn visit_mut_children<F>(&mut self, mut f: F)
788 where
789 F: FnMut(&mut HirScalarExpr),
790 {
791 f(&mut self.aggregate_expr.expr)
792 }
793
794 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795 where
796 F: FnMut(&HirScalarExpr) -> Result<(), E>,
797 E: From<RecursionLimitError>,
798 {
799 f(&self.aggregate_expr.expr)
800 }
801
802 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803 where
804 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805 E: From<RecursionLimitError>,
806 {
807 f(&mut self.aggregate_expr.expr)
808 }
809}
810
/// A scalar expression whose type may not have been decided yet.
///
/// Only `Coerced` (and, recursively, `LiteralRecord`) report a known type
/// from `AbstractExpr::typ`; the remaining variants report
/// `CoercibleColumnType::Uncoerced`.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression whose type is already known.
    Coerced(HirScalarExpr),
    /// A numbered parameter whose type is not yet known.
    Parameter(usize),
    /// An untyped `NULL` literal.
    LiteralNull,
    /// An untyped string literal.
    LiteralString(String),
    /// A record literal whose fields may themselves still be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
843
844impl CoercibleScalarExpr {
845 pub fn type_as(
846 self,
847 ecx: &ExprContext,
848 ty: &SqlScalarType,
849 ) -> Result<HirScalarExpr, PlanError> {
850 let expr = typeconv::plan_coerce(ecx, self, ty)?;
851 let expr_ty = ecx.scalar_type(&expr);
852 if ty != &expr_ty {
853 sql_bail!(
854 "{} must have type {}, not type {}",
855 ecx.name,
856 ecx.humanize_scalar_type(ty, false),
857 ecx.humanize_scalar_type(&expr_ty, false),
858 );
859 }
860 Ok(expr)
861 }
862
863 pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
864 typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
865 }
866
867 pub fn cast_to(
868 self,
869 ecx: &ExprContext,
870 ccx: CastContext,
871 ty: &SqlScalarType,
872 ) -> Result<HirScalarExpr, PlanError> {
873 let expr = typeconv::plan_coerce(ecx, self, ty)?;
874 typeconv::plan_cast(ecx, ccx, expr, ty)
875 }
876}
877
/// The column type of a `CoercibleScalarExpr`, which may be only partially
/// known.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The type is fully known.
    Coerced(SqlColumnType),
    /// A record literal; each field has its own (possibly unknown) type.
    Record(Vec<CoercibleColumnType>),
    /// The type is not yet known.
    Uncoerced,
}
885
886impl CoercibleColumnType {
887 pub fn nullable(&self) -> bool {
889 match self {
890 CoercibleColumnType::Coerced(ct) => ct.nullable,
892
893 CoercibleColumnType::Record(_) => false,
895
896 CoercibleColumnType::Uncoerced => true,
899 }
900 }
901}
902
/// The scalar type of a `CoercibleScalarExpr`, which may be only partially
/// known. This is the scalar-type counterpart of `CoercibleColumnType`.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// The type is fully known.
    Coerced(SqlScalarType),
    /// A record literal; each field has its own (possibly unknown) type.
    Record(Vec<CoercibleColumnType>),
    /// The type is not yet known.
    Uncoerced,
}
910
911impl CoercibleScalarType {
912 pub fn is_coerced(&self) -> bool {
914 matches!(self, CoercibleScalarType::Coerced(_))
915 }
916
917 pub fn as_coerced(&self) -> Option<&SqlScalarType> {
919 match self {
920 CoercibleScalarType::Coerced(t) => Some(t),
921 _ => None,
922 }
923 }
924
925 pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
928 where
929 F: FnOnce(SqlScalarType) -> SqlScalarType,
930 {
931 match self {
932 CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
933 _ => self,
934 }
935 }
936
937 pub fn force_coerced_if_record(&mut self) {
944 fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
945 let mut fields = vec![];
946 for (i, uf) in uncoerced_fields.enumerate() {
947 let name = ColumnName::from(format!("f{}", i + 1));
948 let ty = match uf {
949 CoercibleColumnType::Coerced(ty) => ty,
950 CoercibleColumnType::Record(mut fields) => {
951 convert(fields.drain(..)).nullable(false)
952 }
953 CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
954 };
955 fields.push((name, ty))
956 }
957 SqlScalarType::Record {
958 fields: fields.into(),
959 custom_id: None,
960 }
961 }
962
963 if let CoercibleScalarType::Record(fields) = self {
964 *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
965 }
966 }
967}
968
/// An expression whose column type can be computed, possibly only partially.
pub trait AbstractExpr {
    /// The kind of column type produced by [`AbstractExpr::typ`].
    type Type: AbstractColumnType;

    /// Computes the type of this expression.
    ///
    /// NOTE(review): `outers` and `inner` are presumably the relation types
    /// of the enclosing scopes and of the current scope, and `params` maps
    /// parameter numbers to their types — confirm against the implementors.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
983
984impl AbstractExpr for CoercibleScalarExpr {
985 type Type = CoercibleColumnType;
986
987 fn typ(
988 &self,
989 outers: &[SqlRelationType],
990 inner: &SqlRelationType,
991 params: &BTreeMap<usize, SqlScalarType>,
992 ) -> Self::Type {
993 match self {
994 CoercibleScalarExpr::Coerced(expr) => {
995 CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
996 }
997 CoercibleScalarExpr::LiteralRecord(scalars) => {
998 let fields = scalars
999 .iter()
1000 .map(|s| s.typ(outers, inner, params))
1001 .collect();
1002 CoercibleColumnType::Record(fields)
1003 }
1004 _ => CoercibleColumnType::Uncoerced,
1005 }
1006 }
1007}
1008
/// A column type from which a scalar type can be extracted.
pub trait AbstractColumnType {
    /// The scalar type corresponding to this kind of column type.
    type AbstractScalarType;

    /// Consumes the column type and returns its scalar type.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1020
1021impl AbstractColumnType for SqlColumnType {
1022 type AbstractScalarType = SqlScalarType;
1023
1024 fn scalar_type(self) -> Self::AbstractScalarType {
1025 self.scalar_type
1026 }
1027}
1028
1029impl AbstractColumnType for CoercibleColumnType {
1030 type AbstractScalarType = CoercibleScalarType;
1031
1032 fn scalar_type(self) -> Self::AbstractScalarType {
1033 match self {
1034 CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1035 CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1036 CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1037 }
1038 }
1039}
1040
1041impl From<HirScalarExpr> for CoercibleScalarExpr {
1042 fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1043 CoercibleScalarExpr::Coerced(expr)
1044 }
1045}
1046
/// Addresses a column by scope level and position.
///
/// NOTE(review): `level` presumably counts enclosing scopes outward from the
/// current one (0 = current scope) — confirm against the users of
/// `HirScalarExpr::Column`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Scope level the column belongs to.
    pub level: usize,
    /// Index of the column within that level.
    pub column: usize,
}
1068
/// The kind of a `HirRelationExpr::Join`.
///
/// Each variant displays as its own name (e.g. `LeftOuter`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
    /// Right outer join.
    RightOuter,
    /// Full outer join.
    FullOuter,
}
1076
1077impl fmt::Display for JoinKind {
1078 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1079 write!(
1080 f,
1081 "{}",
1082 match self {
1083 JoinKind::Inner => "Inner",
1084 JoinKind::LeftOuter => "LeftOuter",
1085 JoinKind::RightOuter => "RightOuter",
1086 JoinKind::FullOuter => "FullOuter",
1087 }
1088 )
1089 }
1090}
1091
1092impl JoinKind {
1093 pub fn can_be_correlated(&self) -> bool {
1094 match self {
1095 JoinKind::Inner | JoinKind::LeftOuter => true,
1096 JoinKind::RightOuter | JoinKind::FullOuter => false,
1097 }
1098 }
1099}
1100
/// An aggregate function applied to an argument expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The argument expression that is aggregated.
    pub expr: Box<HirScalarExpr>,
    /// Whether the aggregate is marked `DISTINCT`.
    pub distinct: bool,
}
1107
/// The aggregate functions available in the HIR.
///
/// `AggregateFunc::into_expr` maps every variant to the
/// `mz_expr::AggregateFunc` variant of the same name, except for
/// `FusedWindowAgg`, on which it panics.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    // Maximum, one variant per input type.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // Minimum, one variant per input type.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // Sum, one variant per input type.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// Output type is `Int64`.
    Count,
    /// Output type is `Bool`; the identity datum is `false`.
    Any,
    /// Output type is `Bool`; the identity datum is `true`.
    All,
    /// Output type is `Jsonb`.
    JsonbAgg {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type is `Jsonb`.
    JsonbObjectAgg {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type is a map whose values have type `value_type`.
    MapAgg {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
        /// Type of the values of the resulting map.
        value_type: SqlScalarType,
    },
    /// The identity datum is the empty array.
    ArrayConcat {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
    },
    /// The identity datum is the empty list.
    ListConcat {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type is `String`.
    StringAgg {
        /// Ordering passed through to the lowered function.
        order_by: Vec<ColumnOrder>,
    },
    /// Several aggregates fused into one window aggregation.
    ///
    /// It is lowered by `AggregateWindowExpr::into_expr`; calling
    /// `AggregateFunc::into_expr` or `identity_datum` on it panics.
    FusedWindowAgg {
        /// The fused member functions.
        funcs: Vec<AggregateFunc>,
    },
    /// Placeholder aggregate; lowers to `mz_expr::AggregateFunc::Dummy` and
    /// has `Datum::Dummy` as its identity datum.
    Dummy,
}
1211
1212impl AggregateFunc {
1213 pub fn into_expr(self) -> mz_expr::AggregateFunc {
1215 match self {
1216 AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1217 AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1218 AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1219 AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1220 AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1221 AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1222 AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1223 AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1224 AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1225 AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1226 AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1227 AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1228 AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1229 AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1230 AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1231 AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1232 AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1233 AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1234 AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1235 AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1236 AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1237 AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1238 AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1239 AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1240 AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1241 AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1242 AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1243 AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1244 AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1245 AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1246 AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1247 AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1248 AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1249 AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1250 AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1251 AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1252 AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1253 AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1254 AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1255 AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1256 AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1257 AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1258 AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1259 AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1260 AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1261 AggregateFunc::All => mz_expr::AggregateFunc::All,
1262 AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1263 AggregateFunc::JsonbObjectAgg { order_by } => {
1264 mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1265 }
1266 AggregateFunc::MapAgg {
1267 order_by,
1268 value_type,
1269 } => mz_expr::AggregateFunc::MapAgg {
1270 order_by,
1271 value_type,
1272 },
1273 AggregateFunc::ArrayConcat { order_by } => {
1274 mz_expr::AggregateFunc::ArrayConcat { order_by }
1275 }
1276 AggregateFunc::ListConcat { order_by } => {
1277 mz_expr::AggregateFunc::ListConcat { order_by }
1278 }
1279 AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1280 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1283 panic!("into_expr called on FusedWindowAgg")
1284 }
1285 AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1286 }
1287 }
1288
1289 pub fn identity_datum(&self) -> Datum<'static> {
1296 match self {
1297 AggregateFunc::Any => Datum::False,
1298 AggregateFunc::All => Datum::True,
1299 AggregateFunc::Dummy => Datum::Dummy,
1300 AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1301 AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1302 AggregateFunc::MaxNumeric
1303 | AggregateFunc::MaxInt16
1304 | AggregateFunc::MaxInt32
1305 | AggregateFunc::MaxInt64
1306 | AggregateFunc::MaxUInt16
1307 | AggregateFunc::MaxUInt32
1308 | AggregateFunc::MaxUInt64
1309 | AggregateFunc::MaxMzTimestamp
1310 | AggregateFunc::MaxFloat32
1311 | AggregateFunc::MaxFloat64
1312 | AggregateFunc::MaxBool
1313 | AggregateFunc::MaxString
1314 | AggregateFunc::MaxDate
1315 | AggregateFunc::MaxTimestamp
1316 | AggregateFunc::MaxTimestampTz
1317 | AggregateFunc::MaxInterval
1318 | AggregateFunc::MaxTime
1319 | AggregateFunc::MinNumeric
1320 | AggregateFunc::MinInt16
1321 | AggregateFunc::MinInt32
1322 | AggregateFunc::MinInt64
1323 | AggregateFunc::MinUInt16
1324 | AggregateFunc::MinUInt32
1325 | AggregateFunc::MinUInt64
1326 | AggregateFunc::MinMzTimestamp
1327 | AggregateFunc::MinFloat32
1328 | AggregateFunc::MinFloat64
1329 | AggregateFunc::MinBool
1330 | AggregateFunc::MinString
1331 | AggregateFunc::MinDate
1332 | AggregateFunc::MinTimestamp
1333 | AggregateFunc::MinTimestampTz
1334 | AggregateFunc::MinInterval
1335 | AggregateFunc::MinTime
1336 | AggregateFunc::SumInt16
1337 | AggregateFunc::SumInt32
1338 | AggregateFunc::SumInt64
1339 | AggregateFunc::SumUInt16
1340 | AggregateFunc::SumUInt32
1341 | AggregateFunc::SumUInt64
1342 | AggregateFunc::SumFloat32
1343 | AggregateFunc::SumFloat64
1344 | AggregateFunc::SumNumeric
1345 | AggregateFunc::Count
1346 | AggregateFunc::JsonbAgg { .. }
1347 | AggregateFunc::JsonbObjectAgg { .. }
1348 | AggregateFunc::MapAgg { .. }
1349 | AggregateFunc::StringAgg { .. } => Datum::Null,
1350 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1351 panic!("FusedWindowAgg doesn't have an identity_datum")
1361 }
1362 }
1363 }
1364
1365 pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
1371 let scalar_type = match self {
1372 AggregateFunc::Count => SqlScalarType::Int64,
1373 AggregateFunc::Any => SqlScalarType::Bool,
1374 AggregateFunc::All => SqlScalarType::Bool,
1375 AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
1376 AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
1377 AggregateFunc::StringAgg { .. } => SqlScalarType::String,
1378 AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
1379 AggregateFunc::SumInt64 => SqlScalarType::Numeric {
1380 max_scale: Some(NumericMaxScale::ZERO),
1381 },
1382 AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
1383 AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
1384 max_scale: Some(NumericMaxScale::ZERO),
1385 },
1386 AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
1387 value_type: Box::new(value_type.clone()),
1388 custom_id: None,
1389 },
1390 AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1391 match input_type.scalar_type {
1392 SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1394 _ => unreachable!(),
1395 }
1396 }
1397 AggregateFunc::MaxNumeric
1398 | AggregateFunc::MaxInt16
1399 | AggregateFunc::MaxInt32
1400 | AggregateFunc::MaxInt64
1401 | AggregateFunc::MaxUInt16
1402 | AggregateFunc::MaxUInt32
1403 | AggregateFunc::MaxUInt64
1404 | AggregateFunc::MaxMzTimestamp
1405 | AggregateFunc::MaxFloat32
1406 | AggregateFunc::MaxFloat64
1407 | AggregateFunc::MaxBool
1408 | AggregateFunc::MaxString
1409 | AggregateFunc::MaxDate
1410 | AggregateFunc::MaxTimestamp
1411 | AggregateFunc::MaxTimestampTz
1412 | AggregateFunc::MaxInterval
1413 | AggregateFunc::MaxTime
1414 | AggregateFunc::MinNumeric
1415 | AggregateFunc::MinInt16
1416 | AggregateFunc::MinInt32
1417 | AggregateFunc::MinInt64
1418 | AggregateFunc::MinUInt16
1419 | AggregateFunc::MinUInt32
1420 | AggregateFunc::MinUInt64
1421 | AggregateFunc::MinMzTimestamp
1422 | AggregateFunc::MinFloat32
1423 | AggregateFunc::MinFloat64
1424 | AggregateFunc::MinBool
1425 | AggregateFunc::MinString
1426 | AggregateFunc::MinDate
1427 | AggregateFunc::MinTimestamp
1428 | AggregateFunc::MinTimestampTz
1429 | AggregateFunc::MinInterval
1430 | AggregateFunc::MinTime
1431 | AggregateFunc::SumFloat32
1432 | AggregateFunc::SumFloat64
1433 | AggregateFunc::SumNumeric
1434 | AggregateFunc::Dummy => input_type.scalar_type,
1435 AggregateFunc::FusedWindowAgg { funcs } => {
1436 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1437 SqlScalarType::Record {
1438 fields: funcs
1439 .iter()
1440 .zip_eq(input_types)
1441 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1442 .collect(),
1443 custom_id: None,
1444 }
1445 }
1446 };
1447 let nullable = !matches!(self, AggregateFunc::Count);
1449 scalar_type.nullable(nullable)
1450 }
1451
1452 pub fn is_order_sensitive(&self) -> bool {
1453 use AggregateFunc::*;
1454 matches!(
1455 self,
1456 JsonbAgg { .. }
1457 | JsonbObjectAgg { .. }
1458 | MapAgg { .. }
1459 | ArrayConcat { .. }
1460 | ListConcat { .. }
1461 | StringAgg { .. }
1462 )
1463 }
1464}
1465
1466impl HirRelationExpr {
1467 pub fn typ(
1468 &self,
1469 outers: &[SqlRelationType],
1470 params: &BTreeMap<usize, SqlScalarType>,
1471 ) -> SqlRelationType {
1472 stack::maybe_grow(|| match self {
1473 HirRelationExpr::Constant { typ, .. } => typ.clone(),
1474 HirRelationExpr::Get { typ, .. } => typ.clone(),
1475 HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1476 HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1477 HirRelationExpr::Project { input, outputs } => {
1478 let input_typ = input.typ(outers, params);
1479 SqlRelationType::new(
1480 outputs
1481 .iter()
1482 .map(|&i| input_typ.column_types[i].clone())
1483 .collect(),
1484 )
1485 }
1486 HirRelationExpr::Map { input, scalars } => {
1487 let mut typ = input.typ(outers, params);
1488 for scalar in scalars {
1489 typ.column_types.push(scalar.typ(outers, &typ, params));
1490 }
1491 typ
1492 }
1493 HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1494 HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1495 input.typ(outers, params)
1496 }
1497 HirRelationExpr::Join {
1498 left, right, kind, ..
1499 } => {
1500 let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1501 let right_nullable =
1502 matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1503 let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1504 let nullable = t.nullable || left_nullable;
1505 t.nullable(nullable)
1506 });
1507 let mut outers = outers.to_vec();
1508 outers.insert(0, SqlRelationType::new(lt.clone().collect()));
1509 let rt = right
1510 .typ(&outers, params)
1511 .column_types
1512 .into_iter()
1513 .map(|t| {
1514 let nullable = t.nullable || right_nullable;
1515 t.nullable(nullable)
1516 });
1517 SqlRelationType::new(lt.chain(rt).collect())
1518 }
1519 HirRelationExpr::Reduce {
1520 input,
1521 group_key,
1522 aggregates,
1523 expected_group_size: _,
1524 } => {
1525 let input_typ = input.typ(outers, params);
1526 let mut column_types = group_key
1527 .iter()
1528 .map(|&i| input_typ.column_types[i].clone())
1529 .collect::<Vec<_>>();
1530 for agg in aggregates {
1531 column_types.push(agg.typ(outers, &input_typ, params));
1532 }
1533 SqlRelationType::new(column_types)
1535 }
1536 HirRelationExpr::Distinct { input }
1538 | HirRelationExpr::Negate { input }
1539 | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1540 HirRelationExpr::Union { base, inputs } => {
1541 let mut base_cols = base.typ(outers, params).column_types;
1542 for input in inputs {
1543 for (base_col, col) in base_cols
1544 .iter_mut()
1545 .zip_eq(input.typ(outers, params).column_types)
1546 {
1547 *base_col = base_col.union(&col).unwrap();
1548 }
1549 }
1550 SqlRelationType::new(base_cols)
1551 }
1552 })
1553 }
1554
1555 pub fn arity(&self) -> usize {
1556 match self {
1557 HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1558 HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1559 HirRelationExpr::Let { body, .. } => body.arity(),
1560 HirRelationExpr::LetRec { body, .. } => body.arity(),
1561 HirRelationExpr::Project { outputs, .. } => outputs.len(),
1562 HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1563 HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1564 HirRelationExpr::Filter { input, .. }
1565 | HirRelationExpr::TopK { input, .. }
1566 | HirRelationExpr::Distinct { input }
1567 | HirRelationExpr::Negate { input }
1568 | HirRelationExpr::Threshold { input } => input.arity(),
1569 HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1570 HirRelationExpr::Union { base, .. } => base.arity(),
1571 HirRelationExpr::Reduce {
1572 group_key,
1573 aggregates,
1574 ..
1575 } => group_key.len() + aggregates.len(),
1576 }
1577 }
1578
1579 pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1581 match self {
1582 Self::Constant { rows, typ } => Some((rows, typ)),
1583 _ => None,
1584 }
1585 }
1586
1587 pub fn is_correlated(&self) -> bool {
1590 let mut correlated = false;
1591 #[allow(deprecated)]
1592 self.visit_columns(0, &mut |depth, col| {
1593 if col.level > depth && col.level - depth == 1 {
1594 correlated = true;
1595 }
1596 });
1597 correlated
1598 }
1599
1600 pub fn is_join_identity(&self) -> bool {
1601 match self {
1602 HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1603 _ => false,
1604 }
1605 }
1606
1607 pub fn project(self, outputs: Vec<usize>) -> Self {
1608 if outputs.iter().copied().eq(0..self.arity()) {
1609 self
1611 } else {
1612 HirRelationExpr::Project {
1613 input: Box::new(self),
1614 outputs,
1615 }
1616 }
1617 }
1618
1619 pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1620 if scalars.is_empty() {
1621 self
1623 } else if let HirRelationExpr::Map {
1624 scalars: old_scalars,
1625 input: _,
1626 } = &mut self
1627 {
1628 old_scalars.extend(scalars);
1630 self
1631 } else {
1632 HirRelationExpr::Map {
1633 input: Box::new(self),
1634 scalars,
1635 }
1636 }
1637 }
1638
1639 pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1640 if let HirRelationExpr::Filter {
1641 input: _,
1642 predicates,
1643 } = &mut self
1644 {
1645 predicates.extend(preds);
1646 predicates.sort();
1647 predicates.dedup();
1648 self
1649 } else {
1650 preds.sort();
1651 preds.dedup();
1652 HirRelationExpr::Filter {
1653 input: Box::new(self),
1654 predicates: preds,
1655 }
1656 }
1657 }
1658
1659 pub fn reduce(
1660 self,
1661 group_key: Vec<usize>,
1662 aggregates: Vec<AggregateExpr>,
1663 expected_group_size: Option<u64>,
1664 ) -> Self {
1665 HirRelationExpr::Reduce {
1666 input: Box::new(self),
1667 group_key,
1668 aggregates,
1669 expected_group_size,
1670 }
1671 }
1672
1673 pub fn top_k(
1674 self,
1675 group_key: Vec<usize>,
1676 order_key: Vec<ColumnOrder>,
1677 limit: Option<HirScalarExpr>,
1678 offset: HirScalarExpr,
1679 expected_group_size: Option<u64>,
1680 ) -> Self {
1681 HirRelationExpr::TopK {
1682 input: Box::new(self),
1683 group_key,
1684 order_key,
1685 limit,
1686 offset,
1687 expected_group_size,
1688 }
1689 }
1690
1691 pub fn negate(self) -> Self {
1692 if let HirRelationExpr::Negate { input } = self {
1693 *input
1694 } else {
1695 HirRelationExpr::Negate {
1696 input: Box::new(self),
1697 }
1698 }
1699 }
1700
1701 pub fn distinct(self) -> Self {
1702 if let HirRelationExpr::Distinct { .. } = self {
1703 self
1704 } else {
1705 HirRelationExpr::Distinct {
1706 input: Box::new(self),
1707 }
1708 }
1709 }
1710
1711 pub fn threshold(self) -> Self {
1712 if let HirRelationExpr::Threshold { .. } = self {
1713 self
1714 } else {
1715 HirRelationExpr::Threshold {
1716 input: Box::new(self),
1717 }
1718 }
1719 }
1720
1721 pub fn union(self, other: Self) -> Self {
1722 let mut terms = Vec::new();
1723 if let HirRelationExpr::Union { base, inputs } = self {
1724 terms.push(*base);
1725 terms.extend(inputs);
1726 } else {
1727 terms.push(self);
1728 }
1729 if let HirRelationExpr::Union { base, inputs } = other {
1730 terms.push(*base);
1731 terms.extend(inputs);
1732 } else {
1733 terms.push(other);
1734 }
1735 HirRelationExpr::Union {
1736 base: Box::new(terms.remove(0)),
1737 inputs: terms,
1738 }
1739 }
1740
1741 pub fn exists(self) -> HirScalarExpr {
1742 HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1743 }
1744
1745 pub fn select(self) -> HirScalarExpr {
1746 HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1747 }
1748
1749 pub fn join(
1750 self,
1751 mut right: HirRelationExpr,
1752 on: HirScalarExpr,
1753 kind: JoinKind,
1754 ) -> HirRelationExpr {
1755 if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1756 {
1757 #[allow(deprecated)]
1761 right.visit_columns_mut(0, &mut |depth, col| {
1762 if col.level > depth {
1763 col.level -= 1;
1764 }
1765 });
1766 right
1767 } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1768 self
1769 } else {
1770 HirRelationExpr::Join {
1771 left: Box::new(self),
1772 right: Box::new(right),
1773 on,
1774 kind,
1775 }
1776 }
1777 }
1778
1779 pub fn take(&mut self) -> HirRelationExpr {
1780 mem::replace(
1781 self,
1782 HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1783 )
1784 }
1785
1786 #[deprecated = "Use `Visit::visit_post`."]
1787 pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1788 where
1789 F: FnMut(&'a Self, usize),
1790 {
1791 #[allow(deprecated)]
1792 let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1793 depth: usize|
1794 -> Result<(), ()> {
1795 f(e, depth);
1796 Ok(())
1797 });
1798 }
1799
1800 #[deprecated = "Use `Visit::try_visit_post`."]
1801 pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1802 where
1803 F: FnMut(&'a Self, usize) -> Result<(), E>,
1804 {
1805 #[allow(deprecated)]
1806 self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1807 e.visit_fallible(depth, f)
1808 })?;
1809 f(self, depth)
1810 }
1811
1812 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1813 pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1814 where
1815 F: FnMut(&'a Self, usize) -> Result<(), E>,
1816 {
1817 match self {
1818 HirRelationExpr::Constant { .. }
1819 | HirRelationExpr::Get { .. }
1820 | HirRelationExpr::CallTable { .. } => (),
1821 HirRelationExpr::Let { body, value, .. } => {
1822 f(value, depth)?;
1823 f(body, depth)?;
1824 }
1825 HirRelationExpr::LetRec {
1826 limit: _,
1827 bindings,
1828 body,
1829 } => {
1830 for (_, _, value, _) in bindings.iter() {
1831 f(value, depth)?;
1832 }
1833 f(body, depth)?;
1834 }
1835 HirRelationExpr::Project { input, .. } => {
1836 f(input, depth)?;
1837 }
1838 HirRelationExpr::Map { input, .. } => {
1839 f(input, depth)?;
1840 }
1841 HirRelationExpr::Filter { input, .. } => {
1842 f(input, depth)?;
1843 }
1844 HirRelationExpr::Join { left, right, .. } => {
1845 f(left, depth)?;
1846 f(right, depth + 1)?;
1847 }
1848 HirRelationExpr::Reduce { input, .. } => {
1849 f(input, depth)?;
1850 }
1851 HirRelationExpr::Distinct { input } => {
1852 f(input, depth)?;
1853 }
1854 HirRelationExpr::TopK { input, .. } => {
1855 f(input, depth)?;
1856 }
1857 HirRelationExpr::Negate { input } => {
1858 f(input, depth)?;
1859 }
1860 HirRelationExpr::Threshold { input } => {
1861 f(input, depth)?;
1862 }
1863 HirRelationExpr::Union { base, inputs } => {
1864 f(base, depth)?;
1865 for input in inputs {
1866 f(input, depth)?;
1867 }
1868 }
1869 }
1870 Ok(())
1871 }
1872
1873 #[deprecated = "Use `Visit::visit_mut_post` instead."]
1874 pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1875 where
1876 F: FnMut(&mut Self, usize),
1877 {
1878 #[allow(deprecated)]
1879 let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1880 depth: usize|
1881 -> Result<(), ()> {
1882 f(e, depth);
1883 Ok(())
1884 });
1885 }
1886
1887 #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1888 pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1889 where
1890 F: FnMut(&mut Self, usize) -> Result<(), E>,
1891 {
1892 #[allow(deprecated)]
1893 self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1894 e.visit_mut_fallible(depth, f)
1895 })?;
1896 f(self, depth)
1897 }
1898
1899 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1900 pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1901 where
1902 F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1903 {
1904 match self {
1905 HirRelationExpr::Constant { .. }
1906 | HirRelationExpr::Get { .. }
1907 | HirRelationExpr::CallTable { .. } => (),
1908 HirRelationExpr::Let { body, value, .. } => {
1909 f(value, depth)?;
1910 f(body, depth)?;
1911 }
1912 HirRelationExpr::LetRec {
1913 limit: _,
1914 bindings,
1915 body,
1916 } => {
1917 for (_, _, value, _) in bindings.iter_mut() {
1918 f(value, depth)?;
1919 }
1920 f(body, depth)?;
1921 }
1922 HirRelationExpr::Project { input, .. } => {
1923 f(input, depth)?;
1924 }
1925 HirRelationExpr::Map { input, .. } => {
1926 f(input, depth)?;
1927 }
1928 HirRelationExpr::Filter { input, .. } => {
1929 f(input, depth)?;
1930 }
1931 HirRelationExpr::Join { left, right, .. } => {
1932 f(left, depth)?;
1933 f(right, depth + 1)?;
1934 }
1935 HirRelationExpr::Reduce { input, .. } => {
1936 f(input, depth)?;
1937 }
1938 HirRelationExpr::Distinct { input } => {
1939 f(input, depth)?;
1940 }
1941 HirRelationExpr::TopK { input, .. } => {
1942 f(input, depth)?;
1943 }
1944 HirRelationExpr::Negate { input } => {
1945 f(input, depth)?;
1946 }
1947 HirRelationExpr::Threshold { input } => {
1948 f(input, depth)?;
1949 }
1950 HirRelationExpr::Union { base, inputs } => {
1951 f(base, depth)?;
1952 for input in inputs {
1953 f(input, depth)?;
1954 }
1955 }
1956 }
1957 Ok(())
1958 }
1959
1960 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1961 pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1967 where
1968 F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1969 {
1970 #[allow(deprecated)]
1971 self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1972 depth: usize|
1973 -> Result<(), E> {
1974 match e {
1975 HirRelationExpr::Join { on, .. } => {
1976 f(on, depth)?;
1977 }
1978 HirRelationExpr::Map { scalars, .. } => {
1979 for scalar in scalars {
1980 f(scalar, depth)?;
1981 }
1982 }
1983 HirRelationExpr::CallTable { exprs, .. } => {
1984 for expr in exprs {
1985 f(expr, depth)?;
1986 }
1987 }
1988 HirRelationExpr::Filter { predicates, .. } => {
1989 for predicate in predicates {
1990 f(predicate, depth)?;
1991 }
1992 }
1993 HirRelationExpr::Reduce { aggregates, .. } => {
1994 for aggregate in aggregates {
1995 f(&aggregate.expr, depth)?;
1996 }
1997 }
1998 HirRelationExpr::TopK { limit, offset, .. } => {
1999 if let Some(limit) = limit {
2000 f(limit, depth)?;
2001 }
2002 f(offset, depth)?;
2003 }
2004 HirRelationExpr::Union { .. }
2005 | HirRelationExpr::Let { .. }
2006 | HirRelationExpr::LetRec { .. }
2007 | HirRelationExpr::Project { .. }
2008 | HirRelationExpr::Distinct { .. }
2009 | HirRelationExpr::Negate { .. }
2010 | HirRelationExpr::Threshold { .. }
2011 | HirRelationExpr::Constant { .. }
2012 | HirRelationExpr::Get { .. } => (),
2013 }
2014 Ok(())
2015 })
2016 }
2017
2018 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
2019 pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2021 where
2022 F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2023 {
2024 #[allow(deprecated)]
2025 self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2026 depth: usize|
2027 -> Result<(), E> {
2028 match e {
2029 HirRelationExpr::Join { on, .. } => {
2030 f(on, depth)?;
2031 }
2032 HirRelationExpr::Map { scalars, .. } => {
2033 for scalar in scalars.iter_mut() {
2034 f(scalar, depth)?;
2035 }
2036 }
2037 HirRelationExpr::CallTable { exprs, .. } => {
2038 for expr in exprs.iter_mut() {
2039 f(expr, depth)?;
2040 }
2041 }
2042 HirRelationExpr::Filter { predicates, .. } => {
2043 for predicate in predicates.iter_mut() {
2044 f(predicate, depth)?;
2045 }
2046 }
2047 HirRelationExpr::Reduce { aggregates, .. } => {
2048 for aggregate in aggregates.iter_mut() {
2049 f(&mut aggregate.expr, depth)?;
2050 }
2051 }
2052 HirRelationExpr::TopK { limit, offset, .. } => {
2053 if let Some(limit) = limit {
2054 f(limit, depth)?;
2055 }
2056 f(offset, depth)?;
2057 }
2058 HirRelationExpr::Union { .. }
2059 | HirRelationExpr::Let { .. }
2060 | HirRelationExpr::LetRec { .. }
2061 | HirRelationExpr::Project { .. }
2062 | HirRelationExpr::Distinct { .. }
2063 | HirRelationExpr::Negate { .. }
2064 | HirRelationExpr::Threshold { .. }
2065 | HirRelationExpr::Constant { .. }
2066 | HirRelationExpr::Get { .. } => (),
2067 }
2068 Ok(())
2069 })
2070 }
2071
2072 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2073 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2079 where
2080 F: FnMut(usize, &ColumnRef),
2081 {
2082 #[allow(deprecated)]
2083 let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2084 depth: usize|
2085 -> Result<(), ()> {
2086 e.visit_columns(depth, f);
2087 Ok(())
2088 });
2089 }
2090
2091 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2092 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2094 where
2095 F: FnMut(usize, &mut ColumnRef),
2096 {
2097 #[allow(deprecated)]
2098 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2099 depth: usize|
2100 -> Result<(), ()> {
2101 e.visit_columns_mut(depth, f);
2102 Ok(())
2103 });
2104 }
2105
2106 pub fn bind_parameters(
2109 &mut self,
2110 scx: &StatementContext,
2111 lifetime: QueryLifetime,
2112 params: &Params,
2113 ) -> Result<(), PlanError> {
2114 #[allow(deprecated)]
2115 self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2116 e.bind_parameters(scx, lifetime, params)
2117 })
2118 }
2119
2120 pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2121 let mut contains_parameters = false;
2122 #[allow(deprecated)]
2123 self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2124 if e.contains_parameters() {
2125 contains_parameters = true;
2126 }
2127 Ok::<(), PlanError>(())
2128 })?;
2129 Ok(contains_parameters)
2130 }
2131
2132 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2134 #[allow(deprecated)]
2135 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2136 depth: usize|
2137 -> Result<(), ()> {
2138 e.splice_parameters(params, depth);
2139 Ok(())
2140 });
2141 }
2142
2143 pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2145 let rows = rows
2146 .into_iter()
2147 .map(move |datums| Row::pack_slice(&datums))
2148 .collect();
2149 HirRelationExpr::Constant { rows, typ }
2150 }
2151
2152 pub fn finish_maintained(
2158 &mut self,
2159 finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2160 group_size_hints: GroupSizeHints,
2161 ) {
2162 if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
2163 let old_finishing = mem::replace(
2164 finishing,
2165 HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
2166 );
2167 *self = HirRelationExpr::top_k(
2168 std::mem::replace(
2169 self,
2170 HirRelationExpr::Constant {
2171 rows: vec![],
2172 typ: SqlRelationType::new(Vec::new()),
2173 },
2174 ),
2175 vec![],
2176 old_finishing.order_by,
2177 old_finishing.limit,
2178 old_finishing.offset,
2179 group_size_hints.limit_input_group_size,
2180 )
2181 .project(old_finishing.project);
2182 }
2183 }
2184
2185 pub fn trivial_row_set_finishing_hir(
2190 arity: usize,
2191 ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2192 RowSetFinishing {
2193 order_by: Vec::new(),
2194 limit: None,
2195 offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2196 project: (0..arity).collect(),
2197 }
2198 }
2199
2200 pub fn is_trivial_row_set_finishing_hir(
2205 rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2206 arity: usize,
2207 ) -> bool {
2208 rsf.limit.is_none()
2209 && rsf.order_by.is_empty()
2210 && rsf
2211 .offset
2212 .clone()
2213 .try_into_literal_int64()
2214 .is_ok_and(|o| o == 0)
2215 && rsf.project.iter().copied().eq(0..arity)
2216 }
2217
2218 pub fn could_run_expensive_function(&self) -> bool {
2227 let mut result = false;
2228 if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2229 use HirRelationExpr::*;
2230 use HirScalarExpr::*;
2231
2232 self.visit_children(|scalar: &HirScalarExpr| {
2233 if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2234 result |= match scalar {
2235 Column(..)
2236 | Literal(..)
2237 | CallUnmaterializable(..)
2238 | If { .. }
2239 | Parameter(..)
2240 | Select(..)
2241 | Exists(..) => false,
2242 CallUnary { .. }
2244 | CallBinary { .. }
2245 | CallVariadic { .. }
2246 | Windowing(..) => true,
2247 };
2248 }) {
2249 result = true;
2251 }
2252 });
2253
2254 result |= matches!(e, CallTable { .. } | Reduce { .. });
2257 }) {
2258 result = true;
2260 }
2261
2262 result
2263 }
2264
2265 pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2267 let mut contains = false;
2268 self.visit_post(&mut |expr| {
2269 expr.visit_children(|expr: &HirScalarExpr| {
2270 contains = contains || expr.contains_temporal()
2271 })
2272 })?;
2273 Ok(contains)
2274 }
2275
2276 pub fn contains_unmaterializable(&self) -> Result<bool, RecursionLimitError> {
2278 let mut contains = false;
2279 self.visit_post(&mut |expr| {
2280 expr.visit_children(|expr: &HirScalarExpr| {
2281 contains = contains || expr.contains_unmaterializable()
2282 })
2283 })?;
2284 Ok(contains)
2285 }
2286}
2287
2288impl CollectionPlan for HirRelationExpr {
2289 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2292 if let Self::Get {
2293 id: Id::Global(id), ..
2294 } = self
2295 {
2296 out.insert(*id);
2297 }
2298 self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2299 }
2300}
2301
2302impl VisitChildren<Self> for HirRelationExpr {
2303 fn visit_children<F>(&self, mut f: F)
2304 where
2305 F: FnMut(&Self),
2306 {
2307 VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2311 #[allow(deprecated)]
2312 Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2313 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2314 f(expr.as_ref())
2315 }
2316 _ => (),
2317 });
2318 });
2319
2320 use HirRelationExpr::*;
2321 match self {
2322 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2323 Let {
2324 name: _,
2325 id: _,
2326 value,
2327 body,
2328 } => {
2329 f(value);
2330 f(body);
2331 }
2332 LetRec {
2333 limit: _,
2334 bindings,
2335 body,
2336 } => {
2337 for (_, _, value, _) in bindings.iter() {
2338 f(value);
2339 }
2340 f(body);
2341 }
2342 Project { input, outputs: _ } => f(input),
2343 Map { input, scalars: _ } => {
2344 f(input);
2345 }
2346 CallTable { func: _, exprs: _ } => (),
2347 Filter {
2348 input,
2349 predicates: _,
2350 } => {
2351 f(input);
2352 }
2353 Join {
2354 left,
2355 right,
2356 on: _,
2357 kind: _,
2358 } => {
2359 f(left);
2360 f(right);
2361 }
2362 Reduce {
2363 input,
2364 group_key: _,
2365 aggregates: _,
2366 expected_group_size: _,
2367 } => {
2368 f(input);
2369 }
2370 Distinct { input }
2371 | TopK {
2372 input,
2373 group_key: _,
2374 order_key: _,
2375 limit: _,
2376 offset: _,
2377 expected_group_size: _,
2378 }
2379 | Negate { input }
2380 | Threshold { input } => {
2381 f(input);
2382 }
2383 Union { base, inputs } => {
2384 f(base);
2385 for input in inputs {
2386 f(input);
2387 }
2388 }
2389 }
2390 }
2391
2392 fn visit_mut_children<F>(&mut self, mut f: F)
2393 where
2394 F: FnMut(&mut Self),
2395 {
2396 VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2400 #[allow(deprecated)]
2401 Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2402 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2403 f(expr.as_mut())
2404 }
2405 _ => (),
2406 });
2407 });
2408
2409 use HirRelationExpr::*;
2410 match self {
2411 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2412 Let {
2413 name: _,
2414 id: _,
2415 value,
2416 body,
2417 } => {
2418 f(value);
2419 f(body);
2420 }
2421 LetRec {
2422 limit: _,
2423 bindings,
2424 body,
2425 } => {
2426 for (_, _, value, _) in bindings.iter_mut() {
2427 f(value);
2428 }
2429 f(body);
2430 }
2431 Project { input, outputs: _ } => f(input),
2432 Map { input, scalars: _ } => {
2433 f(input);
2434 }
2435 CallTable { func: _, exprs: _ } => (),
2436 Filter {
2437 input,
2438 predicates: _,
2439 } => {
2440 f(input);
2441 }
2442 Join {
2443 left,
2444 right,
2445 on: _,
2446 kind: _,
2447 } => {
2448 f(left);
2449 f(right);
2450 }
2451 Reduce {
2452 input,
2453 group_key: _,
2454 aggregates: _,
2455 expected_group_size: _,
2456 } => {
2457 f(input);
2458 }
2459 Distinct { input }
2460 | TopK {
2461 input,
2462 group_key: _,
2463 order_key: _,
2464 limit: _,
2465 offset: _,
2466 expected_group_size: _,
2467 }
2468 | Negate { input }
2469 | Threshold { input } => {
2470 f(input);
2471 }
2472 Union { base, inputs } => {
2473 f(base);
2474 for input in inputs {
2475 f(input);
2476 }
2477 }
2478 }
2479 }
2480
2481 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2482 where
2483 F: FnMut(&Self) -> Result<(), E>,
2484 E: From<RecursionLimitError>,
2485 {
2486 VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2490 Visit::try_visit_post(expr, &mut |expr| match expr {
2491 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2492 f(expr.as_ref())
2493 }
2494 _ => Ok(()),
2495 })
2496 })?;
2497
2498 use HirRelationExpr::*;
2499 match self {
2500 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2501 Let {
2502 name: _,
2503 id: _,
2504 value,
2505 body,
2506 } => {
2507 f(value)?;
2508 f(body)?;
2509 }
2510 LetRec {
2511 limit: _,
2512 bindings,
2513 body,
2514 } => {
2515 for (_, _, value, _) in bindings.iter() {
2516 f(value)?;
2517 }
2518 f(body)?;
2519 }
2520 Project { input, outputs: _ } => f(input)?,
2521 Map { input, scalars: _ } => {
2522 f(input)?;
2523 }
2524 CallTable { func: _, exprs: _ } => (),
2525 Filter {
2526 input,
2527 predicates: _,
2528 } => {
2529 f(input)?;
2530 }
2531 Join {
2532 left,
2533 right,
2534 on: _,
2535 kind: _,
2536 } => {
2537 f(left)?;
2538 f(right)?;
2539 }
2540 Reduce {
2541 input,
2542 group_key: _,
2543 aggregates: _,
2544 expected_group_size: _,
2545 } => {
2546 f(input)?;
2547 }
2548 Distinct { input }
2549 | TopK {
2550 input,
2551 group_key: _,
2552 order_key: _,
2553 limit: _,
2554 offset: _,
2555 expected_group_size: _,
2556 }
2557 | Negate { input }
2558 | Threshold { input } => {
2559 f(input)?;
2560 }
2561 Union { base, inputs } => {
2562 f(base)?;
2563 for input in inputs {
2564 f(input)?;
2565 }
2566 }
2567 }
2568 Ok(())
2569 }
2570
2571 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2572 where
2573 F: FnMut(&mut Self) -> Result<(), E>,
2574 E: From<RecursionLimitError>,
2575 {
2576 VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2580 Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2581 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
2582 f(expr.as_mut())
2583 }
2584 _ => Ok(()),
2585 })
2586 })?;
2587
2588 use HirRelationExpr::*;
2589 match self {
2590 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2591 Let {
2592 name: _,
2593 id: _,
2594 value,
2595 body,
2596 } => {
2597 f(value)?;
2598 f(body)?;
2599 }
2600 LetRec {
2601 limit: _,
2602 bindings,
2603 body,
2604 } => {
2605 for (_, _, value, _) in bindings.iter_mut() {
2606 f(value)?;
2607 }
2608 f(body)?;
2609 }
2610 Project { input, outputs: _ } => f(input)?,
2611 Map { input, scalars: _ } => {
2612 f(input)?;
2613 }
2614 CallTable { func: _, exprs: _ } => (),
2615 Filter {
2616 input,
2617 predicates: _,
2618 } => {
2619 f(input)?;
2620 }
2621 Join {
2622 left,
2623 right,
2624 on: _,
2625 kind: _,
2626 } => {
2627 f(left)?;
2628 f(right)?;
2629 }
2630 Reduce {
2631 input,
2632 group_key: _,
2633 aggregates: _,
2634 expected_group_size: _,
2635 } => {
2636 f(input)?;
2637 }
2638 Distinct { input }
2639 | TopK {
2640 input,
2641 group_key: _,
2642 order_key: _,
2643 limit: _,
2644 offset: _,
2645 expected_group_size: _,
2646 }
2647 | Negate { input }
2648 | Threshold { input } => {
2649 f(input)?;
2650 }
2651 Union { base, inputs } => {
2652 f(base)?;
2653 for input in inputs {
2654 f(input)?;
2655 }
2656 }
2657 }
2658 Ok(())
2659 }
2660}
2661
2662impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2663 fn visit_children<F>(&self, mut f: F)
2664 where
2665 F: FnMut(&HirScalarExpr),
2666 {
2667 use HirRelationExpr::*;
2668 match self {
2669 Constant { rows: _, typ: _ }
2670 | Get { id: _, typ: _ }
2671 | Let {
2672 name: _,
2673 id: _,
2674 value: _,
2675 body: _,
2676 }
2677 | LetRec {
2678 limit: _,
2679 bindings: _,
2680 body: _,
2681 }
2682 | Project {
2683 input: _,
2684 outputs: _,
2685 } => (),
2686 Map { input: _, scalars } => {
2687 for scalar in scalars {
2688 f(scalar);
2689 }
2690 }
2691 CallTable { func: _, exprs } => {
2692 for expr in exprs {
2693 f(expr);
2694 }
2695 }
2696 Filter {
2697 input: _,
2698 predicates,
2699 } => {
2700 for predicate in predicates {
2701 f(predicate);
2702 }
2703 }
2704 Join {
2705 left: _,
2706 right: _,
2707 on,
2708 kind: _,
2709 } => f(on),
2710 Reduce {
2711 input: _,
2712 group_key: _,
2713 aggregates,
2714 expected_group_size: _,
2715 } => {
2716 for aggregate in aggregates {
2717 f(aggregate.expr.as_ref());
2718 }
2719 }
2720 TopK {
2721 input: _,
2722 group_key: _,
2723 order_key: _,
2724 limit,
2725 offset,
2726 expected_group_size: _,
2727 } => {
2728 if let Some(limit) = limit {
2729 f(limit)
2730 }
2731 f(offset)
2732 }
2733 Distinct { input: _ }
2734 | Negate { input: _ }
2735 | Threshold { input: _ }
2736 | Union { base: _, inputs: _ } => (),
2737 }
2738 }
2739
2740 fn visit_mut_children<F>(&mut self, mut f: F)
2741 where
2742 F: FnMut(&mut HirScalarExpr),
2743 {
2744 use HirRelationExpr::*;
2745 match self {
2746 Constant { rows: _, typ: _ }
2747 | Get { id: _, typ: _ }
2748 | Let {
2749 name: _,
2750 id: _,
2751 value: _,
2752 body: _,
2753 }
2754 | LetRec {
2755 limit: _,
2756 bindings: _,
2757 body: _,
2758 }
2759 | Project {
2760 input: _,
2761 outputs: _,
2762 } => (),
2763 Map { input: _, scalars } => {
2764 for scalar in scalars {
2765 f(scalar);
2766 }
2767 }
2768 CallTable { func: _, exprs } => {
2769 for expr in exprs {
2770 f(expr);
2771 }
2772 }
2773 Filter {
2774 input: _,
2775 predicates,
2776 } => {
2777 for predicate in predicates {
2778 f(predicate);
2779 }
2780 }
2781 Join {
2782 left: _,
2783 right: _,
2784 on,
2785 kind: _,
2786 } => f(on),
2787 Reduce {
2788 input: _,
2789 group_key: _,
2790 aggregates,
2791 expected_group_size: _,
2792 } => {
2793 for aggregate in aggregates {
2794 f(aggregate.expr.as_mut());
2795 }
2796 }
2797 TopK {
2798 input: _,
2799 group_key: _,
2800 order_key: _,
2801 limit,
2802 offset,
2803 expected_group_size: _,
2804 } => {
2805 if let Some(limit) = limit {
2806 f(limit)
2807 }
2808 f(offset)
2809 }
2810 Distinct { input: _ }
2811 | Negate { input: _ }
2812 | Threshold { input: _ }
2813 | Union { base: _, inputs: _ } => (),
2814 }
2815 }
2816
2817 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2818 where
2819 F: FnMut(&HirScalarExpr) -> Result<(), E>,
2820 E: From<RecursionLimitError>,
2821 {
2822 use HirRelationExpr::*;
2823 match self {
2824 Constant { rows: _, typ: _ }
2825 | Get { id: _, typ: _ }
2826 | Let {
2827 name: _,
2828 id: _,
2829 value: _,
2830 body: _,
2831 }
2832 | LetRec {
2833 limit: _,
2834 bindings: _,
2835 body: _,
2836 }
2837 | Project {
2838 input: _,
2839 outputs: _,
2840 } => (),
2841 Map { input: _, scalars } => {
2842 for scalar in scalars {
2843 f(scalar)?;
2844 }
2845 }
2846 CallTable { func: _, exprs } => {
2847 for expr in exprs {
2848 f(expr)?;
2849 }
2850 }
2851 Filter {
2852 input: _,
2853 predicates,
2854 } => {
2855 for predicate in predicates {
2856 f(predicate)?;
2857 }
2858 }
2859 Join {
2860 left: _,
2861 right: _,
2862 on,
2863 kind: _,
2864 } => f(on)?,
2865 Reduce {
2866 input: _,
2867 group_key: _,
2868 aggregates,
2869 expected_group_size: _,
2870 } => {
2871 for aggregate in aggregates {
2872 f(aggregate.expr.as_ref())?;
2873 }
2874 }
2875 TopK {
2876 input: _,
2877 group_key: _,
2878 order_key: _,
2879 limit,
2880 offset,
2881 expected_group_size: _,
2882 } => {
2883 if let Some(limit) = limit {
2884 f(limit)?
2885 }
2886 f(offset)?
2887 }
2888 Distinct { input: _ }
2889 | Negate { input: _ }
2890 | Threshold { input: _ }
2891 | Union { base: _, inputs: _ } => (),
2892 }
2893 Ok(())
2894 }
2895
2896 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2897 where
2898 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2899 E: From<RecursionLimitError>,
2900 {
2901 use HirRelationExpr::*;
2902 match self {
2903 Constant { rows: _, typ: _ }
2904 | Get { id: _, typ: _ }
2905 | Let {
2906 name: _,
2907 id: _,
2908 value: _,
2909 body: _,
2910 }
2911 | LetRec {
2912 limit: _,
2913 bindings: _,
2914 body: _,
2915 }
2916 | Project {
2917 input: _,
2918 outputs: _,
2919 } => (),
2920 Map { input: _, scalars } => {
2921 for scalar in scalars {
2922 f(scalar)?;
2923 }
2924 }
2925 CallTable { func: _, exprs } => {
2926 for expr in exprs {
2927 f(expr)?;
2928 }
2929 }
2930 Filter {
2931 input: _,
2932 predicates,
2933 } => {
2934 for predicate in predicates {
2935 f(predicate)?;
2936 }
2937 }
2938 Join {
2939 left: _,
2940 right: _,
2941 on,
2942 kind: _,
2943 } => f(on)?,
2944 Reduce {
2945 input: _,
2946 group_key: _,
2947 aggregates,
2948 expected_group_size: _,
2949 } => {
2950 for aggregate in aggregates {
2951 f(aggregate.expr.as_mut())?;
2952 }
2953 }
2954 TopK {
2955 input: _,
2956 group_key: _,
2957 order_key: _,
2958 limit,
2959 offset,
2960 expected_group_size: _,
2961 } => {
2962 if let Some(limit) = limit {
2963 f(limit)?
2964 }
2965 f(offset)?
2966 }
2967 Distinct { input: _ }
2968 | Negate { input: _ }
2969 | Threshold { input: _ }
2970 | Union { base: _, inputs: _ } => (),
2971 }
2972 Ok(())
2973 }
2974}
2975
2976impl HirScalarExpr {
2977 pub fn name(&self) -> Option<Arc<str>> {
2978 use HirScalarExpr::*;
2979 match self {
2980 Column(_, name)
2981 | Parameter(_, name)
2982 | Literal(_, _, name)
2983 | CallUnmaterializable(_, name)
2984 | CallUnary { name, .. }
2985 | CallBinary { name, .. }
2986 | CallVariadic { name, .. }
2987 | If { name, .. }
2988 | Exists(_, name)
2989 | Select(_, name)
2990 | Windowing(_, name) => name.0.clone(),
2991 }
2992 }
2993
2994 pub fn bind_parameters(
2997 &mut self,
2998 scx: &StatementContext,
2999 lifetime: QueryLifetime,
3000 params: &Params,
3001 ) -> Result<(), PlanError> {
3002 #[allow(deprecated)]
3003 self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
3004 if let HirScalarExpr::Parameter(n, name) = e {
3005 let datum = match params.datums.iter().nth(*n - 1) {
3006 None => return Err(PlanError::UnknownParameter(*n)),
3007 Some(datum) => datum,
3008 };
3009 let scalar_type = ¶ms.execute_types[*n - 1];
3010 let row = Row::pack([datum]);
3011 let column_type = scalar_type.clone().nullable(datum.is_null());
3012
3013 let name = if let Some(name) = &name.0 {
3014 Some(Arc::clone(name))
3015 } else {
3016 Some(Arc::from(format!("${n}")))
3017 };
3018
3019 let qcx = QueryContext::root(scx, lifetime);
3020 let ecx = execute_expr_context(&qcx);
3021
3022 *e = plan_cast(
3023 &ecx,
3024 *EXECUTE_CAST_CONTEXT,
3025 HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3026 ¶ms.expected_types[*n - 1],
3027 )
3028 .expect("checked in plan_params");
3029 }
3030 Ok(())
3031 })
3032 }
3033
3034 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3045 #[allow(deprecated)]
3046 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3047 e: &mut HirScalarExpr|
3048 -> Result<(), ()> {
3049 if let HirScalarExpr::Parameter(i, _name) = e {
3050 *e = params[*i - 1].clone();
3051 e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3054 if col.level >= d {
3055 col.level += depth
3056 }
3057 });
3058 }
3059 Ok(())
3060 });
3061 }
3062
3063 pub fn contains_temporal(&self) -> bool {
3065 let mut contains = false;
3066 #[allow(deprecated)]
3067 self.visit_post_nolimit(&mut |e| {
3068 if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3069 contains = true;
3070 }
3071 });
3072 contains
3073 }
3074
3075 pub fn contains_unmaterializable(&self) -> bool {
3077 let mut contains = false;
3078 #[allow(deprecated)]
3079 self.visit_post_nolimit(&mut |e| {
3080 if let Self::CallUnmaterializable(_, _) = e {
3081 contains = true;
3082 }
3083 });
3084 contains
3085 }
3086
3087 pub fn column(index: usize) -> HirScalarExpr {
3091 HirScalarExpr::Column(
3092 ColumnRef {
3093 level: 0,
3094 column: index,
3095 },
3096 TreatAsEqual(None),
3097 )
3098 }
3099
3100 pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3102 HirScalarExpr::Column(cr, TreatAsEqual(None))
3103 }
3104
3105 pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3108 HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3109 }
3110
3111 pub fn parameter(n: usize) -> HirScalarExpr {
3112 HirScalarExpr::Parameter(n, TreatAsEqual(None))
3113 }
3114
3115 pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3116 let col_type = scalar_type.nullable(datum.is_null());
3117 soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3118 let row = Row::pack([datum]);
3119 HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3120 }
3121
3122 pub fn literal_true() -> HirScalarExpr {
3123 HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3124 }
3125
3126 pub fn literal_false() -> HirScalarExpr {
3127 HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3128 }
3129
3130 pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3131 HirScalarExpr::literal(Datum::Null, scalar_type)
3132 }
3133
3134 pub fn literal_1d_array(
3135 datums: Vec<Datum>,
3136 element_scalar_type: SqlScalarType,
3137 ) -> Result<HirScalarExpr, PlanError> {
3138 let scalar_type = match element_scalar_type {
3139 SqlScalarType::Array(_) => {
3140 sql_bail!("cannot build array from array type");
3141 }
3142 typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3143 };
3144
3145 let mut row = Row::default();
3146 row.packer()
3147 .try_push_array(
3148 &[ArrayDimension {
3149 lower_bound: 1,
3150 length: datums.len(),
3151 }],
3152 datums,
3153 )
3154 .expect("array constructed to be valid");
3155
3156 Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3157 }
3158
3159 pub fn as_literal(&self) -> Option<Datum<'_>> {
3160 if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3161 Some(row.unpack_first())
3162 } else {
3163 None
3164 }
3165 }
3166
3167 pub fn is_literal_true(&self) -> bool {
3168 Some(Datum::True) == self.as_literal()
3169 }
3170
3171 pub fn is_literal_false(&self) -> bool {
3172 Some(Datum::False) == self.as_literal()
3173 }
3174
3175 pub fn is_literal_null(&self) -> bool {
3176 Some(Datum::Null) == self.as_literal()
3177 }
3178
3179 pub fn is_constant(&self) -> bool {
3182 let mut worklist = vec![self];
3183 while let Some(expr) = worklist.pop() {
3184 match expr {
3185 Self::Literal(..) => {
3186 }
3188 Self::CallUnary { expr, .. } => {
3189 worklist.push(expr);
3190 }
3191 Self::CallBinary {
3192 func: _,
3193 expr1,
3194 expr2,
3195 name: _,
3196 } => {
3197 worklist.push(expr1);
3198 worklist.push(expr2);
3199 }
3200 Self::CallVariadic {
3201 func: _,
3202 exprs,
3203 name: _,
3204 } => {
3205 worklist.extend(exprs.iter());
3206 }
3207 Self::If {
3209 cond,
3210 then,
3211 els,
3212 name: _,
3213 } => {
3214 worklist.push(cond);
3215 worklist.push(then);
3216 worklist.push(els);
3217 }
3218 _ => {
3219 return false; }
3221 }
3222 }
3223 true
3224 }
3225
3226 pub fn call_unary(self, func: UnaryFunc) -> Self {
3227 HirScalarExpr::CallUnary {
3228 func,
3229 expr: Box::new(self),
3230 name: NameMetadata::default(),
3231 }
3232 }
3233
3234 pub fn call_binary<B: Into<BinaryFunc>>(self, other: Self, func: B) -> Self {
3235 HirScalarExpr::CallBinary {
3236 func: func.into(),
3237 expr1: Box::new(self),
3238 expr2: Box::new(other),
3239 name: NameMetadata::default(),
3240 }
3241 }
3242
3243 pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3244 HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3245 }
3246
3247 pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3248 HirScalarExpr::CallVariadic {
3249 func,
3250 exprs,
3251 name: NameMetadata::default(),
3252 }
3253 }
3254
3255 pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3256 HirScalarExpr::If {
3257 cond: Box::new(cond),
3258 then: Box::new(then),
3259 els: Box::new(els),
3260 name: NameMetadata::default(),
3261 }
3262 }
3263
3264 pub fn windowing(expr: WindowExpr) -> Self {
3265 HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3266 }
3267
3268 pub fn or(self, other: Self) -> Self {
3269 HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3270 }
3271
3272 pub fn and(self, other: Self) -> Self {
3273 HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3274 }
3275
3276 pub fn not(self) -> Self {
3277 self.call_unary(UnaryFunc::Not(func::Not))
3278 }
3279
3280 pub fn call_is_null(self) -> Self {
3281 self.call_unary(UnaryFunc::IsNull(func::IsNull))
3282 }
3283
3284 pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3286 match args.len() {
3287 0 => HirScalarExpr::literal_true(), 1 => args.swap_remove(0),
3289 _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3290 }
3291 }
3292
3293 pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3295 match args.len() {
3296 0 => HirScalarExpr::literal_false(), 1 => args.swap_remove(0),
3298 _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3299 }
3300 }
3301
3302 pub fn take(&mut self) -> Self {
3303 mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3304 }
3305
3306 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3307 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3313 where
3314 F: FnMut(usize, &ColumnRef),
3315 {
3316 #[allow(deprecated)]
3317 let _ = self.visit_recursively(depth, &mut |depth: usize,
3318 e: &HirScalarExpr|
3319 -> Result<(), ()> {
3320 if let HirScalarExpr::Column(col, _name) = e {
3321 f(depth, col)
3322 }
3323 Ok(())
3324 });
3325 }
3326
3327 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3328 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3330 where
3331 F: FnMut(usize, &mut ColumnRef),
3332 {
3333 #[allow(deprecated)]
3334 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3335 e: &mut HirScalarExpr|
3336 -> Result<(), ()> {
3337 if let HirScalarExpr::Column(col, _name) = e {
3338 f(depth, col)
3339 }
3340 Ok(())
3341 });
3342 }
3343
3344 pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3350 where
3351 F: FnMut(usize),
3352 {
3353 #[allow(deprecated)]
3354 let _ = self.visit_recursively(0, &mut |depth: usize,
3355 e: &HirScalarExpr|
3356 -> Result<(), ()> {
3357 if let HirScalarExpr::Column(col, _name) = e {
3358 if col.level == depth {
3359 f(col.column)
3360 }
3361 }
3362 Ok(())
3363 });
3364 }
3365
3366 pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3368 where
3369 F: FnMut(&mut usize),
3370 {
3371 #[allow(deprecated)]
3372 let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3373 e: &mut HirScalarExpr|
3374 -> Result<(), ()> {
3375 if let HirScalarExpr::Column(col, _name) = e {
3376 if col.level == depth {
3377 f(&mut col.column)
3378 }
3379 }
3380 Ok(())
3381 });
3382 }
3383
3384 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3385 pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3389 where
3390 F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3391 {
3392 match self {
3393 HirScalarExpr::Literal(..)
3394 | HirScalarExpr::Parameter(..)
3395 | HirScalarExpr::CallUnmaterializable(..)
3396 | HirScalarExpr::Column(..) => (),
3397 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3398 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3399 expr1.visit_recursively(depth, f)?;
3400 expr2.visit_recursively(depth, f)?;
3401 }
3402 HirScalarExpr::CallVariadic { exprs, .. } => {
3403 for expr in exprs {
3404 expr.visit_recursively(depth, f)?;
3405 }
3406 }
3407 HirScalarExpr::If {
3408 cond,
3409 then,
3410 els,
3411 name: _,
3412 } => {
3413 cond.visit_recursively(depth, f)?;
3414 then.visit_recursively(depth, f)?;
3415 els.visit_recursively(depth, f)?;
3416 }
3417 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3418 #[allow(deprecated)]
3419 expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3420 e.visit_recursively(depth, f)
3421 })?;
3422 }
3423 HirScalarExpr::Windowing(expr, _name) => {
3424 expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3425 }
3426 }
3427 f(depth, self)
3428 }
3429
3430 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3431 pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3433 where
3434 F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3435 {
3436 match self {
3437 HirScalarExpr::Literal(..)
3438 | HirScalarExpr::Parameter(..)
3439 | HirScalarExpr::CallUnmaterializable(..)
3440 | HirScalarExpr::Column(..) => (),
3441 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3442 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3443 expr1.visit_recursively_mut(depth, f)?;
3444 expr2.visit_recursively_mut(depth, f)?;
3445 }
3446 HirScalarExpr::CallVariadic { exprs, .. } => {
3447 for expr in exprs {
3448 expr.visit_recursively_mut(depth, f)?;
3449 }
3450 }
3451 HirScalarExpr::If {
3452 cond,
3453 then,
3454 els,
3455 name: _,
3456 } => {
3457 cond.visit_recursively_mut(depth, f)?;
3458 then.visit_recursively_mut(depth, f)?;
3459 els.visit_recursively_mut(depth, f)?;
3460 }
3461 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3462 #[allow(deprecated)]
3463 expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3464 e.visit_recursively_mut(depth, f)
3465 })?;
3466 }
3467 HirScalarExpr::Windowing(expr, _name) => {
3468 expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3469 }
3470 }
3471 f(depth, self)
3472 }
3473
3474 fn simplify_to_literal(self) -> Option<Row> {
3483 let mut expr = self.lower_uncorrelated().ok()?;
3484 expr.reduce(&[]);
3485 match expr {
3486 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3487 _ => None,
3488 }
3489 }
3490
3491 fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3504 let mut expr = self.lower_uncorrelated().map_err(|err| {
3505 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3506 })?;
3507 expr.reduce(&[]);
3508 match expr {
3509 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3510 mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3511 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3512 ),
3513 _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3514 "Not a constant".to_string(),
3515 )),
3516 }
3517 }
3518
3519 pub fn into_literal_int64(self) -> Option<i64> {
3528 self.simplify_to_literal().and_then(|row| {
3529 let datum = row.unpack_first();
3530 if datum.is_null() {
3531 None
3532 } else {
3533 Some(datum.unwrap_int64())
3534 }
3535 })
3536 }
3537
3538 pub fn into_literal_string(self) -> Option<String> {
3547 self.simplify_to_literal().and_then(|row| {
3548 let datum = row.unpack_first();
3549 if datum.is_null() {
3550 None
3551 } else {
3552 Some(datum.unwrap_str().to_owned())
3553 }
3554 })
3555 }
3556
3557 pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3570 self.simplify_to_literal().and_then(|row| {
3571 let datum = row.unpack_first();
3572 if datum.is_null() {
3573 None
3574 } else {
3575 Some(datum.unwrap_mz_timestamp())
3576 }
3577 })
3578 }
3579
3580 pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3592 if !self.is_constant() {
3598 return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3599 "Expected a constant expression, got {}",
3600 self
3601 )));
3602 }
3603 self.clone()
3604 .simplify_to_literal_with_result()
3605 .and_then(|row| {
3606 let datum = row.unpack_first();
3607 if datum.is_null() {
3608 Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3609 "Expected an expression that evaluates to a non-null value, got {}",
3610 self
3611 )))
3612 } else {
3613 Ok(datum.unwrap_int64())
3614 }
3615 })
3616 }
3617
3618 pub fn contains_parameters(&self) -> bool {
3619 let mut contains_parameters = false;
3620 #[allow(deprecated)]
3621 let _ = self.visit_recursively(0, &mut |_depth: usize,
3622 expr: &HirScalarExpr|
3623 -> Result<(), ()> {
3624 if let HirScalarExpr::Parameter(..) = expr {
3625 contains_parameters = true;
3626 }
3627 Ok(())
3628 });
3629 contains_parameters
3630 }
3631}
3632
3633impl VisitChildren<Self> for HirScalarExpr {
3634 fn visit_children<F>(&self, mut f: F)
3635 where
3636 F: FnMut(&Self),
3637 {
3638 use HirScalarExpr::*;
3639 match self {
3640 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3641 CallUnary { expr, .. } => f(expr),
3642 CallBinary { expr1, expr2, .. } => {
3643 f(expr1);
3644 f(expr2);
3645 }
3646 CallVariadic { exprs, .. } => {
3647 for expr in exprs {
3648 f(expr);
3649 }
3650 }
3651 If {
3652 cond,
3653 then,
3654 els,
3655 name: _,
3656 } => {
3657 f(cond);
3658 f(then);
3659 f(els);
3660 }
3661 Exists(..) | Select(..) => (),
3662 Windowing(expr, _name) => expr.visit_children(f),
3663 }
3664 }
3665
3666 fn visit_mut_children<F>(&mut self, mut f: F)
3667 where
3668 F: FnMut(&mut Self),
3669 {
3670 use HirScalarExpr::*;
3671 match self {
3672 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3673 CallUnary { expr, .. } => f(expr),
3674 CallBinary { expr1, expr2, .. } => {
3675 f(expr1);
3676 f(expr2);
3677 }
3678 CallVariadic { exprs, .. } => {
3679 for expr in exprs {
3680 f(expr);
3681 }
3682 }
3683 If {
3684 cond,
3685 then,
3686 els,
3687 name: _,
3688 } => {
3689 f(cond);
3690 f(then);
3691 f(els);
3692 }
3693 Exists(..) | Select(..) => (),
3694 Windowing(expr, _name) => expr.visit_mut_children(f),
3695 }
3696 }
3697
3698 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3699 where
3700 F: FnMut(&Self) -> Result<(), E>,
3701 E: From<RecursionLimitError>,
3702 {
3703 use HirScalarExpr::*;
3704 match self {
3705 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3706 CallUnary { expr, .. } => f(expr)?,
3707 CallBinary { expr1, expr2, .. } => {
3708 f(expr1)?;
3709 f(expr2)?;
3710 }
3711 CallVariadic { exprs, .. } => {
3712 for expr in exprs {
3713 f(expr)?;
3714 }
3715 }
3716 If {
3717 cond,
3718 then,
3719 els,
3720 name: _,
3721 } => {
3722 f(cond)?;
3723 f(then)?;
3724 f(els)?;
3725 }
3726 Exists(..) | Select(..) => (),
3727 Windowing(expr, _name) => expr.try_visit_children(f)?,
3728 }
3729 Ok(())
3730 }
3731
3732 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3733 where
3734 F: FnMut(&mut Self) -> Result<(), E>,
3735 E: From<RecursionLimitError>,
3736 {
3737 use HirScalarExpr::*;
3738 match self {
3739 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3740 CallUnary { expr, .. } => f(expr)?,
3741 CallBinary { expr1, expr2, .. } => {
3742 f(expr1)?;
3743 f(expr2)?;
3744 }
3745 CallVariadic { exprs, .. } => {
3746 for expr in exprs {
3747 f(expr)?;
3748 }
3749 }
3750 If {
3751 cond,
3752 then,
3753 els,
3754 name: _,
3755 } => {
3756 f(cond)?;
3757 f(then)?;
3758 f(els)?;
3759 }
3760 Exists(..) | Select(..) => (),
3761 Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3762 }
3763 Ok(())
3764 }
3765}
3766
3767impl AbstractExpr for HirScalarExpr {
3768 type Type = SqlColumnType;
3769
3770 fn typ(
3771 &self,
3772 outers: &[SqlRelationType],
3773 inner: &SqlRelationType,
3774 params: &BTreeMap<usize, SqlScalarType>,
3775 ) -> Self::Type {
3776 stack::maybe_grow(|| match self {
3777 HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3778 if *level == 0 {
3779 inner.column_types[*column].clone()
3780 } else {
3781 outers[*level - 1].column_types[*column].clone()
3782 }
3783 }
3784 HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3785 HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3786 HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3787 HirScalarExpr::CallUnary {
3788 expr,
3789 func,
3790 name: _,
3791 } => func.output_type(expr.typ(outers, inner, params)),
3792 HirScalarExpr::CallBinary {
3793 expr1,
3794 expr2,
3795 func,
3796 name: _,
3797 } => func.output_type(
3798 expr1.typ(outers, inner, params),
3799 expr2.typ(outers, inner, params),
3800 ),
3801 HirScalarExpr::CallVariadic {
3802 exprs,
3803 func,
3804 name: _,
3805 } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3806 HirScalarExpr::If {
3807 cond: _,
3808 then,
3809 els,
3810 name: _,
3811 } => {
3812 let then_type = then.typ(outers, inner, params);
3813 let else_type = els.typ(outers, inner, params);
3814 then_type.union(&else_type).unwrap()
3815 }
3816 HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3817 HirScalarExpr::Select(expr, _name) => {
3818 let mut outers = outers.to_vec();
3819 outers.insert(0, inner.clone());
3820 expr.typ(&outers, params)
3821 .column_types
3822 .into_element()
3823 .nullable(true)
3824 }
3825 HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3826 })
3827 }
3828}
3829
3830impl AggregateExpr {
3831 pub fn typ(
3832 &self,
3833 outers: &[SqlRelationType],
3834 inner: &SqlRelationType,
3835 params: &BTreeMap<usize, SqlScalarType>,
3836 ) -> SqlColumnType {
3837 self.func.output_type(self.expr.typ(outers, inner, params))
3838 }
3839
3840 pub fn is_count_asterisk(&self) -> bool {
3848 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3849 }
3850}