1use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::sync::Arc;
17use std::{fmt, mem};
18
19use itertools::Itertools;
20use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
21use mz_expr::visit::{Visit, VisitChildren};
22use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
23use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
25pub use mz_expr::{
26 BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
27};
28use mz_ore::collections::CollectionExt;
29use mz_ore::error::ErrorExt;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::separated;
32use mz_ore::treat_as_equal::TreatAsEqual;
33use mz_ore::{soft_assert_or_log, stack};
34use mz_repr::adt::array::ArrayDimension;
35use mz_repr::adt::numeric::NumericMaxScale;
36use mz_repr::*;
37use serde::{Deserialize, Serialize};
38
39use crate::plan::error::PlanError;
40use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41use crate::plan::typeconv::{self, CastContext, plan_cast};
42use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
43
44use super::plan_utils::GroupSizeHints;
45
46#[allow(missing_debug_implementations)]
47pub struct Hir;
48
49impl IR for Hir {
50 type Relation = HirRelationExpr;
51 type Scalar = HirScalarExpr;
52}
53
54impl AlgExcept for Hir {
55 fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
56 if *all {
57 let rhs = rhs.negate();
58 HirRelationExpr::union(lhs, rhs).threshold()
59 } else {
60 let lhs = lhs.distinct();
61 let rhs = rhs.distinct().negate();
62 HirRelationExpr::union(lhs, rhs).threshold()
63 }
64 }
65
66 fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
67 let mut result = None;
68
69 use HirRelationExpr::*;
70 if let Threshold { input } = expr {
71 if let Union { base: lhs, inputs } = input.as_ref() {
72 if let [rhs] = &inputs[..] {
73 if let Negate { input: rhs } = rhs {
74 match (lhs.as_ref(), rhs.as_ref()) {
75 (Distinct { input: lhs }, Distinct { input: rhs }) => {
76 let all = false;
77 let lhs = lhs.as_ref();
78 let rhs = rhs.as_ref();
79 result = Some(Except { all, lhs, rhs })
80 }
81 (lhs, rhs) => {
82 let all = true;
83 result = Some(Except { all, lhs, rhs })
84 }
85 }
86 }
87 }
88 }
89 }
90
91 result
92 }
93}
94
/// A relational expression in the high-level intermediate representation (HIR).
///
/// Some variants produce rows directly (`Constant`, `Get`, `CallTable`). The rest combine or
/// transform input relations. `AlgExcept for Hir` in this file builds `EXCEPT` out of
/// `Distinct`, `Negate`, `Union` and `Threshold`.
///
/// NOTE: variant order is significant. The derived `PartialOrd`/`Ord` compare variants in
/// declaration order, and non-self-describing serde formats may encode variant indices.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirRelationExpr {
    /// Literal `rows`, whose relation type is `typ`.
    Constant {
        rows: Vec<Row>,
        typ: SqlRelationType,
    },
    /// A reference to the collection or binding identified by `id`, of type `typ`.
    Get {
        id: mz_expr::Id,
        typ: SqlRelationType,
    },
    /// A group of bindings visible in `body`.
    ///
    /// NOTE(review): presumably the bindings may refer to each other recursively, as the name
    /// suggests. Confirm this, and what `limit` bounds, against `mz_expr::LetRecLimit`.
    LetRec {
        /// Optional limit attached to the recursive bindings.
        limit: Option<LetRecLimit>,
        /// Each binding is `(name, id, value, type of value)`.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, SqlRelationType)>,
        /// The expression in which the bindings are visible.
        body: Box<HirRelationExpr>,
    },
    /// Binds `value` to the local identifier `id` within `body`.
    Let {
        /// Human-readable name of the binding.
        name: String,
        /// Identifier by which `body` refers to `value`.
        id: mz_expr::LocalId,
        /// The bound relation.
        value: Box<HirRelationExpr>,
        /// The expression in which the binding is visible.
        body: Box<HirRelationExpr>,
    },
    /// Selects columns of `input` by index (`outputs`).
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Evaluates `scalars` against the rows of `input`. NOTE(review): presumably the results are
    /// appended as new columns; confirm.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// A call to the table function `func` with argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps the rows of `input` that satisfy `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// A join of `left` and `right` on the condition `on`. `kind` may be any `JoinKind`,
    /// including the outer kinds.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Groups `input` by the columns in `group_key` and computes `aggregates` per group.
    ///
    /// `expected_group_size` is an optional size hint (NOTE(review): presumably sourced from
    /// `GroupSizeHints`; confirm).
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows of `input`. Used by `AlgExcept::except` for set `EXCEPT`.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Orders rows within each group and keeps a window of them.
    TopK {
        /// The input relation.
        input: Box<HirRelationExpr>,
        /// Columns defining the groups.
        group_key: Vec<usize>,
        /// Ordering applied within each group.
        order_key: Vec<ColumnOrder>,
        /// Optional number of rows to keep per group, as a scalar expression.
        limit: Option<HirScalarExpr>,
        /// Number of leading rows to skip per group, as a scalar expression.
        /// NOTE(review): presumably required to be constant/parameter-only; confirm in the planner.
        offset: HirScalarExpr,
        /// Optional size hint for the groups.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicities of `input`'s rows. `AlgExcept::except` uses it to subtract
    /// the right-hand side.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Wraps `input`; the outermost node of the `EXCEPT` encoding in `AlgExcept::except`.
    /// NOTE(review): presumably keeps only rows with positive multiplicity; confirm.
    Threshold {
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` and every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
198
/// Optional name attached to a scalar expression.
///
/// NOTE(review): the `TreatAsEqual` wrapper is presumably there so that the name does not
/// affect the derived equality, ordering and hashing of `HirScalarExpr`. Confirm in
/// `mz_ore::treat_as_equal`.
pub type NameMetadata = TreatAsEqual<Option<Arc<str>>>;

/// A scalar expression in the high-level intermediate representation (HIR).
///
/// Every variant carries a `NameMetadata` slot. `Exists` and `Select` embed a relation
/// expression (presumably a subquery), and `Windowing` embeds a window function call.
///
/// NOTE: variant order is significant because of the derived `PartialOrd`/`Ord` and serde.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirScalarExpr {
    /// A column reference, possibly into an enclosing scope (see `ColumnRef::level`).
    Column(ColumnRef, NameMetadata),
    /// A query parameter, identified by a `usize`.
    Parameter(usize, NameMetadata),
    /// A literal value stored as a `Row`, together with its column type.
    Literal(Row, SqlColumnType, NameMetadata),
    /// A call to an unmaterializable function, which takes no argument expressions here.
    CallUnmaterializable(UnmaterializableFunc, NameMetadata),
    /// A call to a unary function.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a binary function.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A call to a function taking any number of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
        name: NameMetadata,
    },
    /// A conditional: `then` if `cond` holds, otherwise `els`.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
        name: NameMetadata,
    },
    /// Tests the embedded relation expression. NOTE(review): presumably for emptiness; confirm.
    Exists(Box<HirRelationExpr>, NameMetadata),
    /// Uses the embedded relation expression as a scalar value. NOTE(review): presumably a
    /// scalar subquery; confirm how multiple rows are handled.
    Select(Box<HirRelationExpr>, NameMetadata),
    /// A window function call.
    Windowing(WindowExpr, NameMetadata),
}
243
/// A window function call: the function plus its partitioning and ordering expressions.
///
/// NOTE(review): `partition_by`/`order_by` presumably come from the SQL
/// `OVER (PARTITION BY ... ORDER BY ...)` clause. Confirm against the planner.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct WindowExpr {
    /// The window function and its own arguments.
    pub func: WindowExprType,
    /// Expressions that partition the input rows.
    pub partition_by: Vec<HirScalarExpr>,
    /// Expressions that order the rows within a partition.
    pub order_by: Vec<HirScalarExpr>,
}
262
263impl WindowExpr {
264 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
265 where
266 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
267 {
268 #[allow(deprecated)]
269 self.func.visit_expressions(f)?;
270 for expr in self.partition_by.iter() {
271 f(expr)?;
272 }
273 for expr in self.order_by.iter() {
274 f(expr)?;
275 }
276 Ok(())
277 }
278
279 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
280 where
281 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
282 {
283 #[allow(deprecated)]
284 self.func.visit_expressions_mut(f)?;
285 for expr in self.partition_by.iter_mut() {
286 f(expr)?;
287 }
288 for expr in self.order_by.iter_mut() {
289 f(expr)?;
290 }
291 Ok(())
292 }
293}
294
295impl VisitChildren<HirScalarExpr> for WindowExpr {
296 fn visit_children<F>(&self, mut f: F)
297 where
298 F: FnMut(&HirScalarExpr),
299 {
300 self.func.visit_children(&mut f);
301 for expr in self.partition_by.iter() {
302 f(expr);
303 }
304 for expr in self.order_by.iter() {
305 f(expr);
306 }
307 }
308
309 fn visit_mut_children<F>(&mut self, mut f: F)
310 where
311 F: FnMut(&mut HirScalarExpr),
312 {
313 self.func.visit_mut_children(&mut f);
314 for expr in self.partition_by.iter_mut() {
315 f(expr);
316 }
317 for expr in self.order_by.iter_mut() {
318 f(expr);
319 }
320 }
321
322 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
323 where
324 F: FnMut(&HirScalarExpr) -> Result<(), E>,
325 E: From<RecursionLimitError>,
326 {
327 self.func.try_visit_children(&mut f)?;
328 for expr in self.partition_by.iter() {
329 f(expr)?;
330 }
331 for expr in self.order_by.iter() {
332 f(expr)?;
333 }
334 Ok(())
335 }
336
337 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
338 where
339 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
340 E: From<RecursionLimitError>,
341 {
342 self.func.try_visit_mut_children(&mut f)?;
343 for expr in self.partition_by.iter_mut() {
344 f(expr)?;
345 }
346 for expr in self.order_by.iter_mut() {
347 f(expr)?;
348 }
349 Ok(())
350 }
351}
352
/// The window function inside a `WindowExpr`, grouped into three families.
///
/// NOTE: variant order is significant because of the derived `PartialOrd`/`Ord` and serde.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum WindowExprType {
    /// A ranking function (`row_number`, `rank`, `dense_rank`); owns no scalar expressions.
    Scalar(ScalarWindowExpr),
    /// A value function (`lag`, `lead`, `first_value`, `last_value`, or a fused group).
    Value(ValueWindowExpr),
    /// An ordinary aggregate (or fused group of aggregates) evaluated over a window.
    Aggregate(AggregateWindowExpr),
}
375
376impl WindowExprType {
377 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
378 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
379 where
380 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
381 {
382 #[allow(deprecated)]
383 match self {
384 Self::Scalar(expr) => expr.visit_expressions(f),
385 Self::Value(expr) => expr.visit_expressions(f),
386 Self::Aggregate(expr) => expr.visit_expressions(f),
387 }
388 }
389
390 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
391 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
392 where
393 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
394 {
395 #[allow(deprecated)]
396 match self {
397 Self::Scalar(expr) => expr.visit_expressions_mut(f),
398 Self::Value(expr) => expr.visit_expressions_mut(f),
399 Self::Aggregate(expr) => expr.visit_expressions_mut(f),
400 }
401 }
402
403 fn typ(
404 &self,
405 outers: &[SqlRelationType],
406 inner: &SqlRelationType,
407 params: &BTreeMap<usize, SqlScalarType>,
408 ) -> SqlColumnType {
409 match self {
410 Self::Scalar(expr) => expr.typ(outers, inner, params),
411 Self::Value(expr) => expr.typ(outers, inner, params),
412 Self::Aggregate(expr) => expr.typ(outers, inner, params),
413 }
414 }
415}
416
417impl VisitChildren<HirScalarExpr> for WindowExprType {
418 fn visit_children<F>(&self, f: F)
419 where
420 F: FnMut(&HirScalarExpr),
421 {
422 match self {
423 Self::Scalar(_) => (),
424 Self::Value(expr) => expr.visit_children(f),
425 Self::Aggregate(expr) => expr.visit_children(f),
426 }
427 }
428
429 fn visit_mut_children<F>(&mut self, f: F)
430 where
431 F: FnMut(&mut HirScalarExpr),
432 {
433 match self {
434 Self::Scalar(_) => (),
435 Self::Value(expr) => expr.visit_mut_children(f),
436 Self::Aggregate(expr) => expr.visit_mut_children(f),
437 }
438 }
439
440 fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
441 where
442 F: FnMut(&HirScalarExpr) -> Result<(), E>,
443 E: From<RecursionLimitError>,
444 {
445 match self {
446 Self::Scalar(_) => Ok(()),
447 Self::Value(expr) => expr.try_visit_children(f),
448 Self::Aggregate(expr) => expr.try_visit_children(f),
449 }
450 }
451
452 fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
453 where
454 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
455 E: From<RecursionLimitError>,
456 {
457 match self {
458 Self::Scalar(_) => Ok(()),
459 Self::Value(expr) => expr.try_visit_mut_children(f),
460 Self::Aggregate(expr) => expr.try_visit_mut_children(f),
461 }
462 }
463}
464
/// A ranking window function (`row_number`, `rank` or `dense_rank`) and its ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ScalarWindowExpr {
    /// Which ranking function to compute.
    pub func: ScalarWindowFunc,
    /// Column ordering handed to the lowered `mz_expr` function by `into_expr`.
    pub order_by: Vec<ColumnOrder>,
}
470
471impl ScalarWindowExpr {
472 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
473 pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
474 where
475 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
476 {
477 match self.func {
478 ScalarWindowFunc::RowNumber => {}
479 ScalarWindowFunc::Rank => {}
480 ScalarWindowFunc::DenseRank => {}
481 }
482 Ok(())
483 }
484
485 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
486 pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
487 where
488 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
489 {
490 match self.func {
491 ScalarWindowFunc::RowNumber => {}
492 ScalarWindowFunc::Rank => {}
493 ScalarWindowFunc::DenseRank => {}
494 }
495 Ok(())
496 }
497
498 fn typ(
499 &self,
500 _outers: &[SqlRelationType],
501 _inner: &SqlRelationType,
502 _params: &BTreeMap<usize, SqlScalarType>,
503 ) -> SqlColumnType {
504 self.func.output_type()
505 }
506
507 pub fn into_expr(self) -> mz_expr::AggregateFunc {
508 match self.func {
509 ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
510 order_by: self.order_by,
511 },
512 ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
513 order_by: self.order_by,
514 },
515 ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
516 order_by: self.order_by,
517 },
518 }
519 }
520}
521
/// Ranking window functions. Each one's output type is a non-nullable `Int64`
/// (see `output_type`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ScalarWindowFunc {
    /// Displayed as `row_number`.
    RowNumber,
    /// Displayed as `rank`.
    Rank,
    /// Displayed as `dense_rank`.
    DenseRank,
}
529
530impl Display for ScalarWindowFunc {
531 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
532 match self {
533 ScalarWindowFunc::RowNumber => write!(f, "row_number"),
534 ScalarWindowFunc::Rank => write!(f, "rank"),
535 ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
536 }
537 }
538}
539
540impl ScalarWindowFunc {
541 pub fn output_type(&self) -> SqlColumnType {
542 match self {
543 ScalarWindowFunc::RowNumber => SqlScalarType::Int64.nullable(false),
544 ScalarWindowFunc::Rank => SqlScalarType::Int64.nullable(false),
545 ScalarWindowFunc::DenseRank => SqlScalarType::Int64.nullable(false),
546 }
547 }
548}
549
/// A value window function (`lag`, `lead`, `first_value`, `last_value`, or a fused group of
/// them) together with its argument and window specification.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ValueWindowExpr {
    /// Which value function(s) to compute.
    pub func: ValueWindowFunc,
    /// The argument expression. For `lag`/`lead` and fused functions its type is a record
    /// (see `ValueWindowFunc::output_type`).
    pub args: Box<HirScalarExpr>,
    /// Column ordering handed to the lowered `mz_expr` function.
    pub order_by: Vec<ColumnOrder>,
    /// Window frame. When lowered, only `first_value`/`last_value` use it (fused groups pass
    /// a copy to each member).
    pub window_frame: WindowFrame,
    /// When lowered, passed to `lag`/`lead` (fused groups pass it to each member).
    pub ignore_nulls: bool,
}
564
565impl Display for ValueWindowFunc {
566 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567 match self {
568 ValueWindowFunc::Lag => write!(f, "lag"),
569 ValueWindowFunc::Lead => write!(f, "lead"),
570 ValueWindowFunc::FirstValue => write!(f, "first_value"),
571 ValueWindowFunc::LastValue => write!(f, "last_value"),
572 ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
573 }
574 }
575}
576
577impl ValueWindowExpr {
578 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
579 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
580 where
581 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
582 {
583 f(&self.args)
584 }
585
586 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
587 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
588 where
589 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
590 {
591 f(&mut self.args)
592 }
593
594 fn typ(
595 &self,
596 outers: &[SqlRelationType],
597 inner: &SqlRelationType,
598 params: &BTreeMap<usize, SqlScalarType>,
599 ) -> SqlColumnType {
600 self.func.output_type(self.args.typ(outers, inner, params))
601 }
602
603 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
605 (
606 self.args,
607 self.func
608 .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
609 )
610 }
611}
612
613impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
614 fn visit_children<F>(&self, mut f: F)
615 where
616 F: FnMut(&HirScalarExpr),
617 {
618 f(&self.args)
619 }
620
621 fn visit_mut_children<F>(&mut self, mut f: F)
622 where
623 F: FnMut(&mut HirScalarExpr),
624 {
625 f(&mut self.args)
626 }
627
628 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
629 where
630 F: FnMut(&HirScalarExpr) -> Result<(), E>,
631 E: From<RecursionLimitError>,
632 {
633 f(&self.args)
634 }
635
636 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
637 where
638 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
639 E: From<RecursionLimitError>,
640 {
641 f(&mut self.args)
642 }
643}
644
/// Value window functions.
///
/// NOTE: variant order is significant because of the derived `PartialOrd`/`Ord` and serde.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ValueWindowFunc {
    /// `lag`: lowered to `mz_expr::AggregateFunc::LagLead` with `LagLeadType::Lag`.
    Lag,
    /// `lead`: lowered to `mz_expr::AggregateFunc::LagLead` with `LagLeadType::Lead`.
    Lead,
    /// `first_value`: lowered to `mz_expr::AggregateFunc::FirstValue`.
    FirstValue,
    /// `last_value`: lowered to `mz_expr::AggregateFunc::LastValue`.
    LastValue,
    /// Several value functions sharing one ordering and frame. Lowered to
    /// `mz_expr::AggregateFunc::FusedValueWindowFunc`; the argument is a record with one
    /// field per member.
    Fused(Vec<ValueWindowFunc>),
}
654
655impl ValueWindowFunc {
656 pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
657 match self {
658 ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
659 input_type.scalar_type.unwrap_record_element_type()[0]
661 .clone()
662 .nullable(true)
663 }
664 ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
665 input_type.scalar_type.nullable(true)
666 }
667 ValueWindowFunc::Fused(funcs) => {
668 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
669 SqlScalarType::Record {
670 fields: funcs
671 .iter()
672 .zip_eq(input_types)
673 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
674 .collect(),
675 custom_id: None,
676 }
677 .nullable(false)
678 }
679 }
680 }
681
682 pub fn into_expr(
683 self,
684 order_by: Vec<ColumnOrder>,
685 window_frame: WindowFrame,
686 ignore_nulls: bool,
687 ) -> mz_expr::AggregateFunc {
688 match self {
689 ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
691 order_by,
692 lag_lead: mz_expr::LagLeadType::Lag,
693 ignore_nulls,
694 },
695 ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
696 order_by,
697 lag_lead: mz_expr::LagLeadType::Lead,
698 ignore_nulls,
699 },
700 ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
701 order_by,
702 window_frame,
703 },
704 ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
705 order_by,
706 window_frame,
707 },
708 ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
709 funcs: funcs
710 .into_iter()
711 .map(|func| {
712 func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
713 })
714 .collect(),
715 order_by,
716 },
717 }
718 }
719}
720
/// An ordinary aggregate, or a `FusedWindowAgg` group of aggregates, evaluated over a window.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateWindowExpr {
    /// The aggregate function and its argument expression.
    pub aggregate_expr: AggregateExpr,
    /// Column ordering handed to the lowered `mz_expr` window aggregate.
    pub order_by: Vec<ColumnOrder>,
    /// Window frame handed to the lowered `mz_expr` window aggregate.
    pub window_frame: WindowFrame,
}
727
728impl AggregateWindowExpr {
729 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
730 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
731 where
732 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
733 {
734 f(&self.aggregate_expr.expr)
735 }
736
737 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
738 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
739 where
740 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
741 {
742 f(&mut self.aggregate_expr.expr)
743 }
744
745 fn typ(
746 &self,
747 outers: &[SqlRelationType],
748 inner: &SqlRelationType,
749 params: &BTreeMap<usize, SqlScalarType>,
750 ) -> SqlColumnType {
751 self.aggregate_expr
752 .func
753 .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
754 }
755
756 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
757 if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
758 (
759 self.aggregate_expr.expr,
760 FusedWindowAggregate {
761 wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
762 order_by: self.order_by,
763 window_frame: self.window_frame,
764 },
765 )
766 } else {
767 (
768 self.aggregate_expr.expr,
769 WindowAggregate {
770 wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
771 order_by: self.order_by,
772 window_frame: self.window_frame,
773 },
774 )
775 }
776 }
777}
778
779impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
780 fn visit_children<F>(&self, mut f: F)
781 where
782 F: FnMut(&HirScalarExpr),
783 {
784 f(&self.aggregate_expr.expr)
785 }
786
787 fn visit_mut_children<F>(&mut self, mut f: F)
788 where
789 F: FnMut(&mut HirScalarExpr),
790 {
791 f(&mut self.aggregate_expr.expr)
792 }
793
794 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
795 where
796 F: FnMut(&HirScalarExpr) -> Result<(), E>,
797 E: From<RecursionLimitError>,
798 {
799 f(&self.aggregate_expr.expr)
800 }
801
802 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
803 where
804 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
805 E: From<RecursionLimitError>,
806 {
807 f(&mut self.aggregate_expr.expr)
808 }
809}
810
/// A scalar expression that may not have a concrete type yet.
///
/// `Coerced` wraps an already-typed `HirScalarExpr`. The other variants are untyped forms that
/// `typeconv::plan_coerce` turns into typed expressions (see `type_as` and `cast_to`).
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression whose type is already known.
    Coerced(HirScalarExpr),
    /// A parameter, identified by a `usize`, whose type is not yet fixed.
    Parameter(usize),
    /// A `NULL` literal.
    LiteralNull,
    /// A string literal whose target type is not yet fixed.
    LiteralString(String),
    /// A record literal whose fields are themselves coercible.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
843
844impl CoercibleScalarExpr {
845 pub fn type_as(
846 self,
847 ecx: &ExprContext,
848 ty: &SqlScalarType,
849 ) -> Result<HirScalarExpr, PlanError> {
850 let expr = typeconv::plan_coerce(ecx, self, ty)?;
851 let expr_ty = ecx.scalar_type(&expr);
852 if ty != &expr_ty {
853 sql_bail!(
854 "{} must have type {}, not type {}",
855 ecx.name,
856 ecx.humanize_scalar_type(ty, false),
857 ecx.humanize_scalar_type(&expr_ty, false),
858 );
859 }
860 Ok(expr)
861 }
862
863 pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
864 typeconv::plan_coerce(ecx, self, &SqlScalarType::String)
865 }
866
867 pub fn cast_to(
868 self,
869 ecx: &ExprContext,
870 ccx: CastContext,
871 ty: &SqlScalarType,
872 ) -> Result<HirScalarExpr, PlanError> {
873 let expr = typeconv::plan_coerce(ecx, self, ty)?;
874 typeconv::plan_cast(ecx, ccx, expr, ty)
875 }
876}
877
/// The column type of a `CoercibleScalarExpr`, as reported by `AbstractExpr::typ`.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// The type of an already-coerced expression.
    Coerced(SqlColumnType),
    /// The per-field types of a record literal.
    Record(Vec<CoercibleColumnType>),
    /// Any other not-yet-coerced expression (parameter, `NULL` or string literal).
    Uncoerced,
}
885
886impl CoercibleColumnType {
887 pub fn nullable(&self) -> bool {
889 match self {
890 CoercibleColumnType::Coerced(ct) => ct.nullable,
892
893 CoercibleColumnType::Record(_) => false,
895
896 CoercibleColumnType::Uncoerced => true,
899 }
900 }
901}
902
/// The scalar type of a `CoercibleScalarExpr`, obtained via `AbstractColumnType::scalar_type`.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A concrete scalar type.
    Coerced(SqlScalarType),
    /// A record literal whose fields may still be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// No type known yet.
    Uncoerced,
}
910
911impl CoercibleScalarType {
912 pub fn is_coerced(&self) -> bool {
914 matches!(self, CoercibleScalarType::Coerced(_))
915 }
916
917 pub fn as_coerced(&self) -> Option<&SqlScalarType> {
919 match self {
920 CoercibleScalarType::Coerced(t) => Some(t),
921 _ => None,
922 }
923 }
924
925 pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
928 where
929 F: FnOnce(SqlScalarType) -> SqlScalarType,
930 {
931 match self {
932 CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
933 _ => self,
934 }
935 }
936
937 pub fn force_coerced_if_record(&mut self) {
944 fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> SqlScalarType {
945 let mut fields = vec![];
946 for (i, uf) in uncoerced_fields.enumerate() {
947 let name = ColumnName::from(format!("f{}", i + 1));
948 let ty = match uf {
949 CoercibleColumnType::Coerced(ty) => ty,
950 CoercibleColumnType::Record(mut fields) => {
951 convert(fields.drain(..)).nullable(false)
952 }
953 CoercibleColumnType::Uncoerced => SqlScalarType::String.nullable(true),
954 };
955 fields.push((name, ty))
956 }
957 SqlScalarType::Record {
958 fields: fields.into(),
959 custom_id: None,
960 }
961 }
962
963 if let CoercibleScalarType::Record(fields) = self {
964 *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
965 }
966 }
967}
968
/// An expression whose type can be computed from a typing context.
///
/// Implemented for `CoercibleScalarExpr` in this file, whose `Type` is `CoercibleColumnType`.
pub trait AbstractExpr {
    /// The kind of column type this expression reports.
    type Type: AbstractColumnType;

    /// Computes the type of this expression.
    ///
    /// `inner` is the type of the relation the expression is evaluated against, and `params`
    /// maps parameter numbers to their types. NOTE(review): `outers` presumably holds the
    /// relation types of enclosing scopes, as addressed by `ColumnRef::level`; confirm.
    fn typ(
        &self,
        outers: &[SqlRelationType],
        inner: &SqlRelationType,
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> Self::Type;
}
983
984impl AbstractExpr for CoercibleScalarExpr {
985 type Type = CoercibleColumnType;
986
987 fn typ(
988 &self,
989 outers: &[SqlRelationType],
990 inner: &SqlRelationType,
991 params: &BTreeMap<usize, SqlScalarType>,
992 ) -> Self::Type {
993 match self {
994 CoercibleScalarExpr::Coerced(expr) => {
995 CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
996 }
997 CoercibleScalarExpr::LiteralRecord(scalars) => {
998 let fields = scalars
999 .iter()
1000 .map(|s| s.typ(outers, inner, params))
1001 .collect();
1002 CoercibleColumnType::Record(fields)
1003 }
1004 _ => CoercibleColumnType::Uncoerced,
1005 }
1006 }
1007}
1008
/// A column type from which a scalar type can be extracted.
///
/// Implemented for `SqlColumnType` (yielding `SqlScalarType`) and `CoercibleColumnType`
/// (yielding `CoercibleScalarType`).
pub trait AbstractColumnType {
    /// The scalar type produced by `scalar_type`.
    type AbstractScalarType;

    /// Consumes the column type and returns its scalar type; nullability is not retained.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1020
1021impl AbstractColumnType for SqlColumnType {
1022 type AbstractScalarType = SqlScalarType;
1023
1024 fn scalar_type(self) -> Self::AbstractScalarType {
1025 self.scalar_type
1026 }
1027}
1028
1029impl AbstractColumnType for CoercibleColumnType {
1030 type AbstractScalarType = CoercibleScalarType;
1031
1032 fn scalar_type(self) -> Self::AbstractScalarType {
1033 match self {
1034 CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1035 CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1036 CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1037 }
1038 }
1039}
1040
1041impl From<HirScalarExpr> for CoercibleScalarExpr {
1042 fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1043 CoercibleScalarExpr::Coerced(expr)
1044 }
1045}
1046
/// A reference to a column, identified by a scope level and a column index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// The scope the column belongs to. NOTE(review): presumably 0 is the innermost scope and
    /// larger values step outward through enclosing scopes; confirm against the planner.
    pub level: usize,
    /// Index of the column within the relation at `level`.
    pub column: usize,
}
1068
/// The kind of a `HirRelationExpr::Join`. Displayed as the variant name (e.g. `LeftOuter`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
    /// Right outer join.
    RightOuter,
    /// Full outer join.
    FullOuter,
}
1076
1077impl fmt::Display for JoinKind {
1078 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1079 write!(
1080 f,
1081 "{}",
1082 match self {
1083 JoinKind::Inner => "Inner",
1084 JoinKind::LeftOuter => "LeftOuter",
1085 JoinKind::RightOuter => "RightOuter",
1086 JoinKind::FullOuter => "FullOuter",
1087 }
1088 )
1089 }
1090}
1091
1092impl JoinKind {
1093 pub fn can_be_correlated(&self) -> bool {
1094 match self {
1095 JoinKind::Inner | JoinKind::LeftOuter => true,
1096 JoinKind::RightOuter | JoinKind::FullOuter => false,
1097 }
1098 }
1099}
1100
/// An aggregate function applied to an argument expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The argument expression. Its type is passed to `func.output_type`. Aggregates such as
    /// `ArrayConcat`/`ListConcat` expect a record here (see `AggregateFunc::output_type`).
    pub expr: Box<HirScalarExpr>,
    /// NOTE(review): presumably set for `agg(DISTINCT ...)`; confirm in the planner.
    pub distinct: bool,
}
1107
/// Aggregate functions available in HIR.
///
/// `AggregateFunc::into_expr` maps every variant except `FusedWindowAgg` one-to-one onto the
/// same-named `mz_expr::AggregateFunc` variant. The type-specific `Max*`/`Min*`/`Sum*` variants
/// each handle a single input type, as their names indicate.
///
/// NOTE: variant order is significant because of the derived `PartialOrd`/`Ord` and serde.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    /// Output type `Int64`.
    Count,
    /// Output type `Bool`; its identity datum is `False`.
    Any,
    /// Output type `Bool`; its identity datum is `True`.
    All,
    /// Output type `Jsonb`.
    JsonbAgg {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type `Jsonb`.
    JsonbObjectAgg {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
    },
    /// Produces a `Map` whose values have type `value_type`.
    MapAgg {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
        /// Value type of the produced map.
        value_type: SqlScalarType,
    },
    /// Output type is the first field of the (record) input; identity datum is an empty array.
    ArrayConcat {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type is the first field of the (record) input; identity datum is an empty list.
    ListConcat {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
    },
    /// Output type `String`.
    StringAgg {
        /// Ordering of the aggregated values.
        order_by: Vec<ColumnOrder>,
    },
    /// Several aggregates fused into one window aggregation.
    ///
    /// Only `AggregateWindowExpr::into_expr` can lower this (to
    /// `mz_expr::AggregateFunc::FusedWindowAggregate`); `into_expr` and `identity_datum`
    /// panic on it.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Lowered to `mz_expr::AggregateFunc::Dummy`; its identity datum is `Datum::Dummy`.
    /// NOTE(review): presumably a placeholder; confirm intended use.
    Dummy,
}
1211
1212impl AggregateFunc {
    /// Converts this HIR aggregate into the same-named `mz_expr::AggregateFunc` variant,
    /// moving any `order_by`/`value_type` payload across unchanged.
    ///
    /// # Panics
    ///
    /// Panics on `FusedWindowAgg`, which has no single-function counterpart.
    /// `AggregateWindowExpr::into_expr` handles it by lowering each member individually.
    pub fn into_expr(self) -> mz_expr::AggregateFunc {
        match self {
            AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
            AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
            AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
            AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
            AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
            AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
            AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
            AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
            AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
            AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
            AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
            AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
            AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
            AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
            AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
            AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
            AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
            AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
            AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
            AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
            AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
            AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
            AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
            AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
            AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
            AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
            AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
            AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
            AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
            AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
            AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
            AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
            AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
            AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
            AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
            AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
            AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
            AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
            AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
            AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
            AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
            AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
            AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
            AggregateFunc::Count => mz_expr::AggregateFunc::Count,
            AggregateFunc::Any => mz_expr::AggregateFunc::Any,
            AggregateFunc::All => mz_expr::AggregateFunc::All,
            AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
            AggregateFunc::JsonbObjectAgg { order_by } => {
                mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
            }
            AggregateFunc::MapAgg {
                order_by,
                value_type,
            } => mz_expr::AggregateFunc::MapAgg {
                order_by,
                value_type,
            },
            AggregateFunc::ArrayConcat { order_by } => {
                mz_expr::AggregateFunc::ArrayConcat { order_by }
            }
            AggregateFunc::ListConcat { order_by } => {
                mz_expr::AggregateFunc::ListConcat { order_by }
            }
            AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                // Fused window aggregates must be lowered member-by-member by the caller
                // (`AggregateWindowExpr::into_expr`), never through this method.
                panic!("into_expr called on FusedWindowAgg")
            }
            AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
        }
    }
1288
    /// Returns this aggregate's identity datum.
    ///
    /// `Any` returns `False`, `All` returns `True`, `Dummy` returns `Datum::Dummy`, and
    /// `ArrayConcat`/`ListConcat` return an empty array or list. Every other supported
    /// aggregate (including `Count`) returns `Null`.
    ///
    /// NOTE(review): presumably "identity" means a value whose inclusion does not change the
    /// aggregate's result (e.g. `Null` being ignored by `Count`); confirm against the callers.
    ///
    /// # Panics
    ///
    /// Panics on `FusedWindowAgg`.
    pub fn identity_datum(&self) -> Datum<'static> {
        match self {
            AggregateFunc::Any => Datum::False,
            AggregateFunc::All => Datum::True,
            AggregateFunc::Dummy => Datum::Dummy,
            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumInt16
            | AggregateFunc::SumInt32
            | AggregateFunc::SumInt64
            | AggregateFunc::SumUInt16
            | AggregateFunc::SumUInt32
            | AggregateFunc::SumUInt64
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Count
            | AggregateFunc::JsonbAgg { .. }
            | AggregateFunc::JsonbObjectAgg { .. }
            | AggregateFunc::MapAgg { .. }
            | AggregateFunc::StringAgg { .. } => Datum::Null,
            AggregateFunc::FusedWindowAgg { funcs: _ } => {
                // A fused group has no single identity; callers must not ask for one.
                panic!("FusedWindowAgg doesn't have an identity_datum")
            }
        }
    }
1364
    /// Computes the output column type of this aggregate from the column type of its input.
    ///
    /// * `Count` yields `Int64`; `Any`/`All` yield `Bool`; the JSONB aggregates yield `Jsonb`;
    ///   `StringAgg` yields `String`; `MapAgg` yields a map with the stored `value_type`.
    /// * Sums widen: 16/32-bit signed sums become `Int64`, 16/32-bit unsigned sums become
    ///   `UInt64`, and 64-bit sums become `Numeric` with a max scale of zero.
    /// * Min/max, float/numeric sums and `Dummy` keep the input's scalar type.
    /// * `ArrayConcat`/`ListConcat` take the type of the first field of their record input.
    /// * `FusedWindowAgg` yields a record with one (unnamed) field per fused function, each
    ///   typed by that function's `output_type` applied to the matching input record field.
    ///
    /// The result is nullable for every aggregate except `Count`.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) if an `ArrayConcat`/`ListConcat` input is not a record, and
    /// (`zip_eq`) if a `FusedWindowAgg` has a different number of functions than its input
    /// record has fields. `unwrap_record_element_column_type` presumably panics on a non-record
    /// input as well — confirm.
    pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
        let scalar_type = match self {
            AggregateFunc::Count => SqlScalarType::Int64,
            AggregateFunc::Any => SqlScalarType::Bool,
            AggregateFunc::All => SqlScalarType::Bool,
            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => SqlScalarType::Int64,
            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
                max_scale: Some(NumericMaxScale::ZERO),
            },
            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
                value_type: Box::new(value_type.clone()),
                custom_id: None,
            },
            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
                // The input is a record whose first field carries the value being concatenated
                // (any further fields are presumably ORDER BY columns — confirm).
                match input_type.scalar_type {
                    SqlScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
                    _ => unreachable!(),
                }
            }
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::Dummy => input_type.scalar_type,
            AggregateFunc::FusedWindowAgg { funcs } => {
                // Each fused function is typed against its own field of the input record.
                let input_types = input_type.scalar_type.unwrap_record_element_column_type();
                SqlScalarType::Record {
                    fields: funcs
                        .iter()
                        .zip_eq(input_types)
                        .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
                        .collect(),
                    custom_id: None,
                }
            }
        };
        // Only `Count` is declared non-nullable; every other aggregate may produce NULL.
        let nullable = !matches!(self, AggregateFunc::Count);
        scalar_type.nullable(nullable)
    }
1451
1452 pub fn is_order_sensitive(&self) -> bool {
1453 use AggregateFunc::*;
1454 matches!(
1455 self,
1456 JsonbAgg { .. }
1457 | JsonbObjectAgg { .. }
1458 | MapAgg { .. }
1459 | ArrayConcat { .. }
1460 | ListConcat { .. }
1461 | StringAgg { .. }
1462 )
1463 }
1464}
1465
1466impl HirRelationExpr {
    /// Computes the relation type (column types and their nullability) of this expression.
    ///
    /// `outers` lists the types of the enclosing scopes and `params` gives parameter types
    /// (presumably keyed by parameter index — confirm); both are forwarded to scalar and
    /// aggregate typing. When typing the right input of a `Join`, the join's left type is
    /// inserted at position 0 of `outers`.
    ///
    /// The whole computation runs under `stack::maybe_grow`, presumably to avoid stack
    /// overflow on deeply nested expressions.
    ///
    /// # Panics
    ///
    /// Panics if a `Project`/`Reduce` references a column index out of range, if `Union`
    /// inputs differ in arity (`zip_eq`), or if their column types cannot be unioned
    /// (`unwrap`).
    pub fn typ(
        &self,
        outers: &[SqlRelationType],
        params: &BTreeMap<usize, SqlScalarType>,
    ) -> SqlRelationType {
        stack::maybe_grow(|| match self {
            HirRelationExpr::Constant { typ, .. } => typ.clone(),
            HirRelationExpr::Get { typ, .. } => typ.clone(),
            HirRelationExpr::Let { body, .. } => body.typ(outers, params),
            HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
            HirRelationExpr::Project { input, outputs } => {
                let input_typ = input.typ(outers, params);
                SqlRelationType::new(
                    outputs
                        .iter()
                        .map(|&i| input_typ.column_types[i].clone())
                        .collect(),
                )
            }
            HirRelationExpr::Map { input, scalars } => {
                // Each scalar is typed against the input plus all previously mapped columns.
                let mut typ = input.typ(outers, params);
                for scalar in scalars {
                    typ.column_types.push(scalar.typ(outers, &typ, params));
                }
                typ
            }
            HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
            HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
                input.typ(outers, params)
            }
            HirRelationExpr::Join {
                left, right, kind, ..
            } => {
                // The side that may be padded with NULLs by an outer join becomes nullable.
                let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
                let right_nullable =
                    matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
                let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
                    let nullable = t.nullable || left_nullable;
                    t.nullable(nullable)
                });
                // The right input is typed with the (adjusted) left type as the innermost
                // outer scope.
                let mut outers = outers.to_vec();
                outers.insert(0, SqlRelationType::new(lt.clone().collect()));
                let rt = right
                    .typ(&outers, params)
                    .column_types
                    .into_iter()
                    .map(|t| {
                        let nullable = t.nullable || right_nullable;
                        t.nullable(nullable)
                    });
                SqlRelationType::new(lt.chain(rt).collect())
            }
            HirRelationExpr::Reduce {
                input,
                group_key,
                aggregates,
                expected_group_size: _,
            } => {
                // Output columns: the group key columns, then one column per aggregate.
                let input_typ = input.typ(outers, params);
                let mut column_types = group_key
                    .iter()
                    .map(|&i| input_typ.column_types[i].clone())
                    .collect::<Vec<_>>();
                for agg in aggregates {
                    column_types.push(agg.typ(outers, &input_typ, params));
                }
                SqlRelationType::new(column_types)
            }
            HirRelationExpr::Distinct { input }
            | HirRelationExpr::Negate { input }
            | HirRelationExpr::Threshold { input } => input.typ(outers, params),
            HirRelationExpr::Union { base, inputs } => {
                // Column types are unioned position by position across all branches.
                let mut base_cols = base.typ(outers, params).column_types;
                for input in inputs {
                    for (base_col, col) in base_cols
                        .iter_mut()
                        .zip_eq(input.typ(outers, params).column_types)
                    {
                        *base_col = base_col.union(&col).unwrap();
                    }
                }
                SqlRelationType::new(base_cols)
            }
        })
    }
1554
1555 pub fn arity(&self) -> usize {
1556 match self {
1557 HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1558 HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1559 HirRelationExpr::Let { body, .. } => body.arity(),
1560 HirRelationExpr::LetRec { body, .. } => body.arity(),
1561 HirRelationExpr::Project { outputs, .. } => outputs.len(),
1562 HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1563 HirRelationExpr::CallTable { func, exprs: _ } => func.output_arity(),
1564 HirRelationExpr::Filter { input, .. }
1565 | HirRelationExpr::TopK { input, .. }
1566 | HirRelationExpr::Distinct { input }
1567 | HirRelationExpr::Negate { input }
1568 | HirRelationExpr::Threshold { input } => input.arity(),
1569 HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1570 HirRelationExpr::Union { base, .. } => base.arity(),
1571 HirRelationExpr::Reduce {
1572 group_key,
1573 aggregates,
1574 ..
1575 } => group_key.len() + aggregates.len(),
1576 }
1577 }
1578
1579 pub fn as_const(&self) -> Option<(&Vec<Row>, &SqlRelationType)> {
1581 match self {
1582 Self::Constant { rows, typ } => Some((rows, typ)),
1583 _ => None,
1584 }
1585 }
1586
1587 pub fn is_correlated(&self) -> bool {
1590 let mut correlated = false;
1591 #[allow(deprecated)]
1592 self.visit_columns(0, &mut |depth, col| {
1593 if col.level > depth && col.level - depth == 1 {
1594 correlated = true;
1595 }
1596 });
1597 correlated
1598 }
1599
1600 pub fn is_join_identity(&self) -> bool {
1601 match self {
1602 HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1603 _ => false,
1604 }
1605 }
1606
1607 pub fn project(self, outputs: Vec<usize>) -> Self {
1608 if outputs.iter().copied().eq(0..self.arity()) {
1609 self
1611 } else {
1612 HirRelationExpr::Project {
1613 input: Box::new(self),
1614 outputs,
1615 }
1616 }
1617 }
1618
1619 pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1620 if scalars.is_empty() {
1621 self
1623 } else if let HirRelationExpr::Map {
1624 scalars: old_scalars,
1625 input: _,
1626 } = &mut self
1627 {
1628 old_scalars.extend(scalars);
1630 self
1631 } else {
1632 HirRelationExpr::Map {
1633 input: Box::new(self),
1634 scalars,
1635 }
1636 }
1637 }
1638
1639 pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1640 if let HirRelationExpr::Filter {
1641 input: _,
1642 predicates,
1643 } = &mut self
1644 {
1645 predicates.extend(preds);
1646 predicates.sort();
1647 predicates.dedup();
1648 self
1649 } else {
1650 preds.sort();
1651 preds.dedup();
1652 HirRelationExpr::Filter {
1653 input: Box::new(self),
1654 predicates: preds,
1655 }
1656 }
1657 }
1658
1659 pub fn reduce(
1660 self,
1661 group_key: Vec<usize>,
1662 aggregates: Vec<AggregateExpr>,
1663 expected_group_size: Option<u64>,
1664 ) -> Self {
1665 HirRelationExpr::Reduce {
1666 input: Box::new(self),
1667 group_key,
1668 aggregates,
1669 expected_group_size,
1670 }
1671 }
1672
1673 pub fn top_k(
1674 self,
1675 group_key: Vec<usize>,
1676 order_key: Vec<ColumnOrder>,
1677 limit: Option<HirScalarExpr>,
1678 offset: HirScalarExpr,
1679 expected_group_size: Option<u64>,
1680 ) -> Self {
1681 HirRelationExpr::TopK {
1682 input: Box::new(self),
1683 group_key,
1684 order_key,
1685 limit,
1686 offset,
1687 expected_group_size,
1688 }
1689 }
1690
1691 pub fn negate(self) -> Self {
1692 if let HirRelationExpr::Negate { input } = self {
1693 *input
1694 } else {
1695 HirRelationExpr::Negate {
1696 input: Box::new(self),
1697 }
1698 }
1699 }
1700
1701 pub fn distinct(self) -> Self {
1702 if let HirRelationExpr::Distinct { .. } = self {
1703 self
1704 } else {
1705 HirRelationExpr::Distinct {
1706 input: Box::new(self),
1707 }
1708 }
1709 }
1710
1711 pub fn threshold(self) -> Self {
1712 if let HirRelationExpr::Threshold { .. } = self {
1713 self
1714 } else {
1715 HirRelationExpr::Threshold {
1716 input: Box::new(self),
1717 }
1718 }
1719 }
1720
1721 pub fn union(self, other: Self) -> Self {
1722 let mut terms = Vec::new();
1723 if let HirRelationExpr::Union { base, inputs } = self {
1724 terms.push(*base);
1725 terms.extend(inputs);
1726 } else {
1727 terms.push(self);
1728 }
1729 if let HirRelationExpr::Union { base, inputs } = other {
1730 terms.push(*base);
1731 terms.extend(inputs);
1732 } else {
1733 terms.push(other);
1734 }
1735 HirRelationExpr::Union {
1736 base: Box::new(terms.remove(0)),
1737 inputs: terms,
1738 }
1739 }
1740
1741 pub fn exists(self) -> HirScalarExpr {
1742 HirScalarExpr::Exists(Box::new(self), NameMetadata::default())
1743 }
1744
1745 pub fn select(self) -> HirScalarExpr {
1746 HirScalarExpr::Select(Box::new(self), NameMetadata::default())
1747 }
1748
1749 pub fn join(
1750 self,
1751 mut right: HirRelationExpr,
1752 on: HirScalarExpr,
1753 kind: JoinKind,
1754 ) -> HirRelationExpr {
1755 if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1756 {
1757 #[allow(deprecated)]
1761 right.visit_columns_mut(0, &mut |depth, col| {
1762 if col.level > depth {
1763 col.level -= 1;
1764 }
1765 });
1766 right
1767 } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1768 self
1769 } else {
1770 HirRelationExpr::Join {
1771 left: Box::new(self),
1772 right: Box::new(right),
1773 on,
1774 kind,
1775 }
1776 }
1777 }
1778
1779 pub fn take(&mut self) -> HirRelationExpr {
1780 mem::replace(
1781 self,
1782 HirRelationExpr::constant(vec![], SqlRelationType::new(Vec::new())),
1783 )
1784 }
1785
    /// Calls `f` on every relation node of this tree in post-order (children before parents),
    /// passing the depth at which each node is visited.
    ///
    /// The root is visited at `depth`. Children are enumerated by
    /// [`HirRelationExpr::visit1`], which visits the right input of a `Join` one level deeper
    /// and does not descend into subqueries nested inside scalar expressions.
    #[deprecated = "Use `Visit::visit_post`."]
    pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
    where
        F: FnMut(&'a Self, usize),
    {
        // Adapt the infallible callback to `visit_fallible`. The adapter never returns `Err`,
        // so the discarded result is always `Ok(())`.
        #[allow(deprecated)]
        let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
                                                 depth: usize|
         -> Result<(), ()> {
            f(e, depth);
            Ok(())
        });
    }
1799
    /// Fallible post-order traversal: calls `f` on every relation node below `self`, then on
    /// `self`, passing each node's depth.
    ///
    /// The root is visited at `depth`; child depths follow [`HirRelationExpr::visit1`].
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use `Visit::try_visit_post`."]
    pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&'a Self, usize) -> Result<(), E>,
    {
        // Recurse into each direct child first, then visit this node (post-order).
        #[allow(deprecated)]
        self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
            e.visit_fallible(depth, f)
        })?;
        f(self, depth)
    }
1811
    /// Calls `f` once on each direct relation child of `self` (not recursively), with the
    /// depth the child lives at.
    ///
    /// Every child is at `depth`, except the right input of a `Join`, which is at `depth + 1`.
    /// Children are visited in field order: `Let` value then body, `LetRec` bindings then body,
    /// `Join` left then right, `Union` base then inputs. Leaves (`Constant`, `Get`,
    /// `CallTable`) have no children. Subqueries nested inside scalar expressions are NOT
    /// visited.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
    pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
    where
        F: FnMut(&'a Self, usize) -> Result<(), E>,
    {
        match self {
            HirRelationExpr::Constant { .. }
            | HirRelationExpr::Get { .. }
            | HirRelationExpr::CallTable { .. } => (),
            HirRelationExpr::Let { body, value, .. } => {
                f(value, depth)?;
                f(body, depth)?;
            }
            HirRelationExpr::LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter() {
                    f(value, depth)?;
                }
                f(body, depth)?;
            }
            HirRelationExpr::Project { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Map { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Filter { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Join { left, right, .. } => {
                f(left, depth)?;
                // The right input sits inside the scope introduced by the join.
                f(right, depth + 1)?;
            }
            HirRelationExpr::Reduce { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Distinct { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::TopK { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Negate { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::Threshold { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::Union { base, inputs } => {
                f(base, depth)?;
                for input in inputs {
                    f(input, depth)?;
                }
            }
        }
        Ok(())
    }
1872
    /// Mutable post-order traversal: calls `f` on every relation node below `self`, then on
    /// `self`, passing each node's depth.
    ///
    /// The root is visited at `depth`; child depths follow
    /// [`HirRelationExpr::visit1_mut`].
    #[deprecated = "Use `Visit::visit_mut_post` instead."]
    pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
    where
        F: FnMut(&mut Self, usize),
    {
        // Adapt the infallible callback to `visit_mut_fallible`. The adapter never returns
        // `Err`, so the discarded result is always `Ok(())`.
        #[allow(deprecated)]
        let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
                                                     depth: usize|
         -> Result<(), ()> {
            f(e, depth);
            Ok(())
        });
    }
1886
    /// Fallible mutable post-order traversal: calls `f` on every relation node below `self`,
    /// then on `self`, passing each node's depth.
    ///
    /// The root is visited at `depth`; child depths follow
    /// [`HirRelationExpr::visit1_mut`].
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
    pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut Self, usize) -> Result<(), E>,
    {
        // Recurse into each direct child first, then visit this node (post-order).
        #[allow(deprecated)]
        self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
            e.visit_mut_fallible(depth, f)
        })?;
        f(self, depth)
    }
1898
    /// Calls `f` once on each direct relation child of `self` (not recursively), mutably,
    /// with the depth the child lives at.
    ///
    /// Mirrors [`HirRelationExpr::visit1`]: every child is at `depth` except the right input
    /// of a `Join`, which is at `depth + 1`. Children are visited in the same field order, and
    /// subqueries nested inside scalar expressions are NOT visited.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
    pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
    where
        F: FnMut(&'a mut Self, usize) -> Result<(), E>,
    {
        match self {
            HirRelationExpr::Constant { .. }
            | HirRelationExpr::Get { .. }
            | HirRelationExpr::CallTable { .. } => (),
            HirRelationExpr::Let { body, value, .. } => {
                f(value, depth)?;
                f(body, depth)?;
            }
            HirRelationExpr::LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter_mut() {
                    f(value, depth)?;
                }
                f(body, depth)?;
            }
            HirRelationExpr::Project { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Map { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Filter { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Join { left, right, .. } => {
                f(left, depth)?;
                // The right input sits inside the scope introduced by the join.
                f(right, depth + 1)?;
            }
            HirRelationExpr::Reduce { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Distinct { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::TopK { input, .. } => {
                f(input, depth)?;
            }
            HirRelationExpr::Negate { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::Threshold { input } => {
                f(input, depth)?;
            }
            HirRelationExpr::Union { base, inputs } => {
                f(base, depth)?;
                for input in inputs {
                    f(input, depth)?;
                }
            }
        }
        Ok(())
    }
1959
    /// Calls `f` on every scalar expression directly owned by a relation node of this tree,
    /// passing the depth of the owning node.
    ///
    /// The scalars visited are: `Join` conditions, `Map` scalars, `CallTable` arguments,
    /// `Filter` predicates, `Reduce` aggregate inputs, and `TopK` limit (if any) and offset.
    /// Nodes are visited in the post-order of [`HirRelationExpr::visit_fallible`], starting at
    /// `depth`. This method does not itself look inside the scalars; that is up to `f`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
    pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
    {
        #[allow(deprecated)]
        self.visit_fallible(depth, &mut |e: &HirRelationExpr,
                                         depth: usize|
         -> Result<(), E> {
            match e {
                HirRelationExpr::Join { on, .. } => {
                    f(on, depth)?;
                }
                HirRelationExpr::Map { scalars, .. } => {
                    for scalar in scalars {
                        f(scalar, depth)?;
                    }
                }
                HirRelationExpr::CallTable { exprs, .. } => {
                    for expr in exprs {
                        f(expr, depth)?;
                    }
                }
                HirRelationExpr::Filter { predicates, .. } => {
                    for predicate in predicates {
                        f(predicate, depth)?;
                    }
                }
                HirRelationExpr::Reduce { aggregates, .. } => {
                    for aggregate in aggregates {
                        f(&aggregate.expr, depth)?;
                    }
                }
                HirRelationExpr::TopK { limit, offset, .. } => {
                    if let Some(limit) = limit {
                        f(limit, depth)?;
                    }
                    f(offset, depth)?;
                }
                // These nodes own no scalar expressions.
                HirRelationExpr::Union { .. }
                | HirRelationExpr::Let { .. }
                | HirRelationExpr::LetRec { .. }
                | HirRelationExpr::Project { .. }
                | HirRelationExpr::Distinct { .. }
                | HirRelationExpr::Negate { .. }
                | HirRelationExpr::Threshold { .. }
                | HirRelationExpr::Constant { .. }
                | HirRelationExpr::Get { .. } => (),
            }
            Ok(())
        })
    }
2017
    /// Mutable counterpart of [`HirRelationExpr::visit_scalar_expressions`]: calls `f` on
    /// every scalar expression directly owned by a relation node of this tree, passing the
    /// depth of the owning node.
    ///
    /// The same scalars are visited (`Join` conditions, `Map` scalars, `CallTable` arguments,
    /// `Filter` predicates, `Reduce` aggregate inputs, `TopK` limit and offset), with nodes
    /// visited in the post-order of [`HirRelationExpr::visit_mut_fallible`], starting at
    /// `depth`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `f`.
    #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
    pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
    {
        #[allow(deprecated)]
        self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
                                             depth: usize|
         -> Result<(), E> {
            match e {
                HirRelationExpr::Join { on, .. } => {
                    f(on, depth)?;
                }
                HirRelationExpr::Map { scalars, .. } => {
                    for scalar in scalars.iter_mut() {
                        f(scalar, depth)?;
                    }
                }
                HirRelationExpr::CallTable { exprs, .. } => {
                    for expr in exprs.iter_mut() {
                        f(expr, depth)?;
                    }
                }
                HirRelationExpr::Filter { predicates, .. } => {
                    for predicate in predicates.iter_mut() {
                        f(predicate, depth)?;
                    }
                }
                HirRelationExpr::Reduce { aggregates, .. } => {
                    for aggregate in aggregates.iter_mut() {
                        f(&mut aggregate.expr, depth)?;
                    }
                }
                HirRelationExpr::TopK { limit, offset, .. } => {
                    if let Some(limit) = limit {
                        f(limit, depth)?;
                    }
                    f(offset, depth)?;
                }
                // These nodes own no scalar expressions.
                HirRelationExpr::Union { .. }
                | HirRelationExpr::Let { .. }
                | HirRelationExpr::LetRec { .. }
                | HirRelationExpr::Project { .. }
                | HirRelationExpr::Distinct { .. }
                | HirRelationExpr::Negate { .. }
                | HirRelationExpr::Threshold { .. }
                | HirRelationExpr::Constant { .. }
                | HirRelationExpr::Get { .. } => (),
            }
            Ok(())
        })
    }
2071
    /// Calls `f` on the column references of every scalar expression owned by a node of this
    /// tree, delegating to `HirScalarExpr::visit_columns` with the owning node's depth.
    ///
    /// Scalars are enumerated by [`HirRelationExpr::visit_scalar_expressions`], starting at
    /// `depth`. Which references `HirScalarExpr::visit_columns` reports (and the depth it
    /// passes for references inside subqueries) is defined by that method — confirm there.
    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
    pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
    where
        F: FnMut(usize, &ColumnRef),
    {
        // The adapter never fails, so the discarded result is always `Ok(())`.
        #[allow(deprecated)]
        let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
                                                           depth: usize|
         -> Result<(), ()> {
            e.visit_columns(depth, f);
            Ok(())
        });
    }
2090
    /// Mutable counterpart of [`HirRelationExpr::visit_columns`]: calls `f` on the column
    /// references of every scalar expression owned by a node of this tree, delegating to
    /// `HirScalarExpr::visit_columns_mut` with the owning node's depth.
    ///
    /// Scalars are enumerated by [`HirRelationExpr::visit_scalar_expressions_mut`], starting
    /// at `depth`.
    #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
    pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
    where
        F: FnMut(usize, &mut ColumnRef),
    {
        // The adapter never fails, so the discarded result is always `Ok(())`.
        #[allow(deprecated)]
        let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
                                                               depth: usize|
         -> Result<(), ()> {
            e.visit_columns_mut(depth, f);
            Ok(())
        });
    }
2105
2106 pub fn bind_parameters(
2109 &mut self,
2110 scx: &StatementContext,
2111 lifetime: QueryLifetime,
2112 params: &Params,
2113 ) -> Result<(), PlanError> {
2114 #[allow(deprecated)]
2115 self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2116 e.bind_parameters(scx, lifetime, params)
2117 })
2118 }
2119
2120 pub fn contains_parameters(&self) -> Result<bool, PlanError> {
2121 let mut contains_parameters = false;
2122 #[allow(deprecated)]
2123 self.visit_scalar_expressions(0, &mut |e: &HirScalarExpr, _: usize| {
2124 if e.contains_parameters() {
2125 contains_parameters = true;
2126 }
2127 Ok::<(), PlanError>(())
2128 })?;
2129 Ok(contains_parameters)
2130 }
2131
2132 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2134 #[allow(deprecated)]
2135 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2136 depth: usize|
2137 -> Result<(), ()> {
2138 e.splice_parameters(params, depth);
2139 Ok(())
2140 });
2141 }
2142
2143 pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
2145 let rows = rows
2146 .into_iter()
2147 .map(move |datums| Row::pack_slice(&datums))
2148 .collect();
2149 HirRelationExpr::Constant { rows, typ }
2150 }
2151
    /// Folds a non-trivial `finishing` into this expression itself.
    ///
    /// If `finishing` is not already trivial for this expression's arity, `self` is replaced
    /// by `TopK(self)` — no grouping, using the finishing's `order_by`, `limit` and `offset`,
    /// with `group_size_hints.limit_input_group_size` as the expected group size — followed by
    /// the finishing's projection. `finishing` is then reset to the trivial finishing for the
    /// projected arity, so applying it afterwards is a no-op. Presumably this is used when the
    /// finishing must be maintained as part of the query rather than applied to results —
    /// confirm at call sites.
    pub fn finish_maintained(
        &mut self,
        finishing: &mut RowSetFinishing<HirScalarExpr, HirScalarExpr>,
        group_size_hints: GroupSizeHints,
    ) {
        if !HirRelationExpr::is_trivial_row_set_finishing_hir(finishing, self.arity()) {
            // Swap in a trivial finishing whose arity is the old projection's length (the
            // arity of the projected result), keeping the old one to build the TopK from.
            let old_finishing = mem::replace(
                finishing,
                HirRelationExpr::trivial_row_set_finishing_hir(finishing.project.len()),
            );
            // Move `self` out (leaving a placeholder) so it can become the TopK's input.
            *self = HirRelationExpr::top_k(
                std::mem::replace(
                    self,
                    HirRelationExpr::Constant {
                        rows: vec![],
                        typ: SqlRelationType::new(Vec::new()),
                    },
                ),
                vec![],
                old_finishing.order_by,
                old_finishing.limit,
                old_finishing.offset,
                group_size_hints.limit_input_group_size,
            )
            .project(old_finishing.project);
        }
    }
2184
2185 pub fn trivial_row_set_finishing_hir(
2190 arity: usize,
2191 ) -> RowSetFinishing<HirScalarExpr, HirScalarExpr> {
2192 RowSetFinishing {
2193 order_by: Vec::new(),
2194 limit: None,
2195 offset: HirScalarExpr::literal(Datum::Int64(0), SqlScalarType::Int64),
2196 project: (0..arity).collect(),
2197 }
2198 }
2199
2200 pub fn is_trivial_row_set_finishing_hir(
2205 rsf: &RowSetFinishing<HirScalarExpr, HirScalarExpr>,
2206 arity: usize,
2207 ) -> bool {
2208 rsf.limit.is_none()
2209 && rsf.order_by.is_empty()
2210 && rsf
2211 .offset
2212 .clone()
2213 .try_into_literal_int64()
2214 .is_ok_and(|o| o == 0)
2215 && rsf.project.iter().copied().eq(0..arity)
2216 }
2217
2218 pub fn could_run_expensive_function(&self) -> bool {
2227 let mut result = false;
2228 if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2229 use HirRelationExpr::*;
2230 use HirScalarExpr::*;
2231
2232 self.visit_children(|scalar: &HirScalarExpr| {
2233 if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2234 result |= match scalar {
2235 Column(..)
2236 | Literal(..)
2237 | CallUnmaterializable(..)
2238 | If { .. }
2239 | Parameter(..)
2240 | Select(..)
2241 | Exists(..) => false,
2242 CallUnary { .. }
2244 | CallBinary { .. }
2245 | CallVariadic { .. }
2246 | Windowing(..) => true,
2247 };
2248 }) {
2249 result = true;
2251 }
2252 });
2253
2254 result |= matches!(e, CallTable { .. } | Reduce { .. });
2257 }) {
2258 result = true;
2260 }
2261
2262 result
2263 }
2264
2265 pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2267 let mut contains = false;
2268 self.visit_post(&mut |expr| {
2269 expr.visit_children(|expr: &HirScalarExpr| {
2270 contains = contains || expr.contains_temporal()
2271 })
2272 })?;
2273 Ok(contains)
2274 }
2275}
2276
2277impl CollectionPlan for HirRelationExpr {
2278 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2281 if let Self::Get {
2282 id: Id::Global(id), ..
2283 } = self
2284 {
2285 out.insert(*id);
2286 }
2287 self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2288 }
2289}
2290
/// Enumerates the child relations of a `HirRelationExpr`.
///
/// For each node, the children are reported in two phases:
/// 1. the relations of `Exists`/`Select` subqueries found by a post-order walk of each scalar
///    expression the node owns (as enumerated by `VisitChildren<HirScalarExpr>`);
/// 2. the node's direct relation inputs, in field order (`Let` value then body, `LetRec`
///    binding values then body, `Join` left then right, `Union` base then inputs).
///
/// No depth information is tracked here; unlike the deprecated `visit1`, the right input of
/// a `Join` is reported like any other child.
impl VisitChildren<Self> for HirRelationExpr {
    /// Calls `f` on every child relation of `self` (see the impl-level documentation).
    fn visit_children<F>(&self, mut f: F)
    where
        F: FnMut(&Self),
    {
        // Phase 1: subquery relations hidden inside this node's scalar expressions. The
        // deprecated `visit_post_nolimit` is used because this method cannot report errors;
        // presumably it performs no recursion-limit check — confirm.
        VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
            #[allow(deprecated)]
            Visit::visit_post_nolimit(expr, &mut |expr| match expr {
                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
                    f(expr.as_ref())
                }
                _ => (),
            });
        });

        // Phase 2: direct relation inputs.
        use HirRelationExpr::*;
        match self {
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value);
                f(body);
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter() {
                    f(value);
                }
                f(body);
            }
            Project { input, outputs: _ } => f(input),
            Map { input, scalars: _ } => {
                f(input);
            }
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input);
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left);
                f(right);
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input);
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input);
            }
            Union { base, inputs } => {
                f(base);
                for input in inputs {
                    f(input);
                }
            }
        }
    }

    /// Calls `f` on every child relation of `self`, mutably, in the same order as
    /// `visit_children`.
    fn visit_mut_children<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut Self),
    {
        // Phase 1: subquery relations hidden inside this node's scalar expressions (same
        // no-limit caveat as in `visit_children`).
        VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
            #[allow(deprecated)]
            Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
                    f(expr.as_mut())
                }
                _ => (),
            });
        });

        // Phase 2: direct relation inputs.
        use HirRelationExpr::*;
        match self {
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value);
                f(body);
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter_mut() {
                    f(value);
                }
                f(body);
            }
            Project { input, outputs: _ } => f(input),
            Map { input, scalars: _ } => {
                f(input);
            }
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input);
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left);
                f(right);
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input);
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input);
            }
            Union { base, inputs } => {
                f(base);
                for input in inputs {
                    f(input);
                }
            }
        }
    }

    /// Fallible version of `visit_children`: calls `f` on every child relation, stopping at
    /// and returning the first error.
    ///
    /// The scalar walk uses `Visit::try_visit_post`, so a `RecursionLimitError` from it is
    /// converted into `E` and returned as well.
    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&Self) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // Phase 1: subquery relations hidden inside this node's scalar expressions.
        VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
            Visit::try_visit_post(expr, &mut |expr| match expr {
                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
                    f(expr.as_ref())
                }
                _ => Ok(()),
            })
        })?;

        // Phase 2: direct relation inputs.
        use HirRelationExpr::*;
        match self {
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value)?;
                f(body)?;
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter() {
                    f(value)?;
                }
                f(body)?;
            }
            Project { input, outputs: _ } => f(input)?,
            Map { input, scalars: _ } => {
                f(input)?;
            }
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input)?;
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left)?;
                f(right)?;
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input)?;
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input)?;
            }
            Union { base, inputs } => {
                f(base)?;
                for input in inputs {
                    f(input)?;
                }
            }
        }
        Ok(())
    }

    /// Fallible, mutable version of `visit_children`: calls `f` on every child relation,
    /// stopping at and returning the first error (including a converted
    /// `RecursionLimitError` from the scalar walk).
    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&mut Self) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // Phase 1: subquery relations hidden inside this node's scalar expressions.
        VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
            Visit::try_visit_mut_post(expr, &mut |expr| match expr {
                HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
                    f(expr.as_mut())
                }
                _ => Ok(()),
            })
        })?;

        // Phase 2: direct relation inputs.
        use HirRelationExpr::*;
        match self {
            Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
            Let {
                name: _,
                id: _,
                value,
                body,
            } => {
                f(value)?;
                f(body)?;
            }
            LetRec {
                limit: _,
                bindings,
                body,
            } => {
                for (_, _, value, _) in bindings.iter_mut() {
                    f(value)?;
                }
                f(body)?;
            }
            Project { input, outputs: _ } => f(input)?,
            Map { input, scalars: _ } => {
                f(input)?;
            }
            CallTable { func: _, exprs: _ } => (),
            Filter {
                input,
                predicates: _,
            } => {
                f(input)?;
            }
            Join {
                left,
                right,
                on: _,
                kind: _,
            } => {
                f(left)?;
                f(right)?;
            }
            Reduce {
                input,
                group_key: _,
                aggregates: _,
                expected_group_size: _,
            } => {
                f(input)?;
            }
            Distinct { input }
            | TopK {
                input,
                group_key: _,
                order_key: _,
                limit: _,
                offset: _,
                expected_group_size: _,
            }
            | Negate { input }
            | Threshold { input } => {
                f(input)?;
            }
            Union { base, inputs } => {
                f(base)?;
                for input in inputs {
                    f(input)?;
                }
            }
        }
        Ok(())
    }
}
2650
2651impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2652 fn visit_children<F>(&self, mut f: F)
2653 where
2654 F: FnMut(&HirScalarExpr),
2655 {
2656 use HirRelationExpr::*;
2657 match self {
2658 Constant { rows: _, typ: _ }
2659 | Get { id: _, typ: _ }
2660 | Let {
2661 name: _,
2662 id: _,
2663 value: _,
2664 body: _,
2665 }
2666 | LetRec {
2667 limit: _,
2668 bindings: _,
2669 body: _,
2670 }
2671 | Project {
2672 input: _,
2673 outputs: _,
2674 } => (),
2675 Map { input: _, scalars } => {
2676 for scalar in scalars {
2677 f(scalar);
2678 }
2679 }
2680 CallTable { func: _, exprs } => {
2681 for expr in exprs {
2682 f(expr);
2683 }
2684 }
2685 Filter {
2686 input: _,
2687 predicates,
2688 } => {
2689 for predicate in predicates {
2690 f(predicate);
2691 }
2692 }
2693 Join {
2694 left: _,
2695 right: _,
2696 on,
2697 kind: _,
2698 } => f(on),
2699 Reduce {
2700 input: _,
2701 group_key: _,
2702 aggregates,
2703 expected_group_size: _,
2704 } => {
2705 for aggregate in aggregates {
2706 f(aggregate.expr.as_ref());
2707 }
2708 }
2709 TopK {
2710 input: _,
2711 group_key: _,
2712 order_key: _,
2713 limit,
2714 offset,
2715 expected_group_size: _,
2716 } => {
2717 if let Some(limit) = limit {
2718 f(limit)
2719 }
2720 f(offset)
2721 }
2722 Distinct { input: _ }
2723 | Negate { input: _ }
2724 | Threshold { input: _ }
2725 | Union { base: _, inputs: _ } => (),
2726 }
2727 }
2728
2729 fn visit_mut_children<F>(&mut self, mut f: F)
2730 where
2731 F: FnMut(&mut HirScalarExpr),
2732 {
2733 use HirRelationExpr::*;
2734 match self {
2735 Constant { rows: _, typ: _ }
2736 | Get { id: _, typ: _ }
2737 | Let {
2738 name: _,
2739 id: _,
2740 value: _,
2741 body: _,
2742 }
2743 | LetRec {
2744 limit: _,
2745 bindings: _,
2746 body: _,
2747 }
2748 | Project {
2749 input: _,
2750 outputs: _,
2751 } => (),
2752 Map { input: _, scalars } => {
2753 for scalar in scalars {
2754 f(scalar);
2755 }
2756 }
2757 CallTable { func: _, exprs } => {
2758 for expr in exprs {
2759 f(expr);
2760 }
2761 }
2762 Filter {
2763 input: _,
2764 predicates,
2765 } => {
2766 for predicate in predicates {
2767 f(predicate);
2768 }
2769 }
2770 Join {
2771 left: _,
2772 right: _,
2773 on,
2774 kind: _,
2775 } => f(on),
2776 Reduce {
2777 input: _,
2778 group_key: _,
2779 aggregates,
2780 expected_group_size: _,
2781 } => {
2782 for aggregate in aggregates {
2783 f(aggregate.expr.as_mut());
2784 }
2785 }
2786 TopK {
2787 input: _,
2788 group_key: _,
2789 order_key: _,
2790 limit,
2791 offset,
2792 expected_group_size: _,
2793 } => {
2794 if let Some(limit) = limit {
2795 f(limit)
2796 }
2797 f(offset)
2798 }
2799 Distinct { input: _ }
2800 | Negate { input: _ }
2801 | Threshold { input: _ }
2802 | Union { base: _, inputs: _ } => (),
2803 }
2804 }
2805
2806 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2807 where
2808 F: FnMut(&HirScalarExpr) -> Result<(), E>,
2809 E: From<RecursionLimitError>,
2810 {
2811 use HirRelationExpr::*;
2812 match self {
2813 Constant { rows: _, typ: _ }
2814 | Get { id: _, typ: _ }
2815 | Let {
2816 name: _,
2817 id: _,
2818 value: _,
2819 body: _,
2820 }
2821 | LetRec {
2822 limit: _,
2823 bindings: _,
2824 body: _,
2825 }
2826 | Project {
2827 input: _,
2828 outputs: _,
2829 } => (),
2830 Map { input: _, scalars } => {
2831 for scalar in scalars {
2832 f(scalar)?;
2833 }
2834 }
2835 CallTable { func: _, exprs } => {
2836 for expr in exprs {
2837 f(expr)?;
2838 }
2839 }
2840 Filter {
2841 input: _,
2842 predicates,
2843 } => {
2844 for predicate in predicates {
2845 f(predicate)?;
2846 }
2847 }
2848 Join {
2849 left: _,
2850 right: _,
2851 on,
2852 kind: _,
2853 } => f(on)?,
2854 Reduce {
2855 input: _,
2856 group_key: _,
2857 aggregates,
2858 expected_group_size: _,
2859 } => {
2860 for aggregate in aggregates {
2861 f(aggregate.expr.as_ref())?;
2862 }
2863 }
2864 TopK {
2865 input: _,
2866 group_key: _,
2867 order_key: _,
2868 limit,
2869 offset,
2870 expected_group_size: _,
2871 } => {
2872 if let Some(limit) = limit {
2873 f(limit)?
2874 }
2875 f(offset)?
2876 }
2877 Distinct { input: _ }
2878 | Negate { input: _ }
2879 | Threshold { input: _ }
2880 | Union { base: _, inputs: _ } => (),
2881 }
2882 Ok(())
2883 }
2884
2885 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2886 where
2887 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2888 E: From<RecursionLimitError>,
2889 {
2890 use HirRelationExpr::*;
2891 match self {
2892 Constant { rows: _, typ: _ }
2893 | Get { id: _, typ: _ }
2894 | Let {
2895 name: _,
2896 id: _,
2897 value: _,
2898 body: _,
2899 }
2900 | LetRec {
2901 limit: _,
2902 bindings: _,
2903 body: _,
2904 }
2905 | Project {
2906 input: _,
2907 outputs: _,
2908 } => (),
2909 Map { input: _, scalars } => {
2910 for scalar in scalars {
2911 f(scalar)?;
2912 }
2913 }
2914 CallTable { func: _, exprs } => {
2915 for expr in exprs {
2916 f(expr)?;
2917 }
2918 }
2919 Filter {
2920 input: _,
2921 predicates,
2922 } => {
2923 for predicate in predicates {
2924 f(predicate)?;
2925 }
2926 }
2927 Join {
2928 left: _,
2929 right: _,
2930 on,
2931 kind: _,
2932 } => f(on)?,
2933 Reduce {
2934 input: _,
2935 group_key: _,
2936 aggregates,
2937 expected_group_size: _,
2938 } => {
2939 for aggregate in aggregates {
2940 f(aggregate.expr.as_mut())?;
2941 }
2942 }
2943 TopK {
2944 input: _,
2945 group_key: _,
2946 order_key: _,
2947 limit,
2948 offset,
2949 expected_group_size: _,
2950 } => {
2951 if let Some(limit) = limit {
2952 f(limit)?
2953 }
2954 f(offset)?
2955 }
2956 Distinct { input: _ }
2957 | Negate { input: _ }
2958 | Threshold { input: _ }
2959 | Union { base: _, inputs: _ } => (),
2960 }
2961 Ok(())
2962 }
2963}
2964
2965impl HirScalarExpr {
2966 pub fn name(&self) -> Option<Arc<str>> {
2967 use HirScalarExpr::*;
2968 match self {
2969 Column(_, name)
2970 | Parameter(_, name)
2971 | Literal(_, _, name)
2972 | CallUnmaterializable(_, name)
2973 | CallUnary { name, .. }
2974 | CallBinary { name, .. }
2975 | CallVariadic { name, .. }
2976 | If { name, .. }
2977 | Exists(_, name)
2978 | Select(_, name)
2979 | Windowing(_, name) => name.0.clone(),
2980 }
2981 }
2982
2983 pub fn bind_parameters(
2986 &mut self,
2987 scx: &StatementContext,
2988 lifetime: QueryLifetime,
2989 params: &Params,
2990 ) -> Result<(), PlanError> {
2991 #[allow(deprecated)]
2992 self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
2993 if let HirScalarExpr::Parameter(n, name) = e {
2994 let datum = match params.datums.iter().nth(*n - 1) {
2995 None => return Err(PlanError::UnknownParameter(*n)),
2996 Some(datum) => datum,
2997 };
2998 let scalar_type = ¶ms.execute_types[*n - 1];
2999 let row = Row::pack([datum]);
3000 let column_type = scalar_type.clone().nullable(datum.is_null());
3001
3002 let name = if let Some(name) = &name.0 {
3003 Some(Arc::clone(name))
3004 } else {
3005 Some(Arc::from(format!("${n}")))
3006 };
3007
3008 let qcx = QueryContext::root(scx, lifetime);
3009 let ecx = execute_expr_context(&qcx);
3010
3011 *e = plan_cast(
3012 &ecx,
3013 *EXECUTE_CAST_CONTEXT,
3014 HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3015 ¶ms.expected_types[*n - 1],
3016 )
3017 .expect("checked in plan_params");
3018 }
3019 Ok(())
3020 })
3021 }
3022
3023 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
3034 #[allow(deprecated)]
3035 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3036 e: &mut HirScalarExpr|
3037 -> Result<(), ()> {
3038 if let HirScalarExpr::Parameter(i, _name) = e {
3039 *e = params[*i - 1].clone();
3040 e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
3043 if col.level >= d {
3044 col.level += depth
3045 }
3046 });
3047 }
3048 Ok(())
3049 });
3050 }
3051
3052 pub fn contains_temporal(&self) -> bool {
3054 let mut contains = false;
3055 #[allow(deprecated)]
3056 self.visit_post_nolimit(&mut |e| {
3057 if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e {
3058 contains = true;
3059 }
3060 });
3061 contains
3062 }
3063
3064 pub fn column(index: usize) -> HirScalarExpr {
3068 HirScalarExpr::Column(
3069 ColumnRef {
3070 level: 0,
3071 column: index,
3072 },
3073 TreatAsEqual(None),
3074 )
3075 }
3076
3077 pub fn unnamed_column(cr: ColumnRef) -> HirScalarExpr {
3079 HirScalarExpr::Column(cr, TreatAsEqual(None))
3080 }
3081
3082 pub fn named_column(cr: ColumnRef, name: Arc<str>) -> HirScalarExpr {
3085 HirScalarExpr::Column(cr, TreatAsEqual(Some(name)))
3086 }
3087
3088 pub fn parameter(n: usize) -> HirScalarExpr {
3089 HirScalarExpr::Parameter(n, TreatAsEqual(None))
3090 }
3091
3092 pub fn literal(datum: Datum, scalar_type: SqlScalarType) -> HirScalarExpr {
3093 let col_type = scalar_type.nullable(datum.is_null());
3094 soft_assert_or_log!(datum.is_instance_of_sql(&col_type), "type is correct");
3095 let row = Row::pack([datum]);
3096 HirScalarExpr::Literal(row, col_type, TreatAsEqual(None))
3097 }
3098
3099 pub fn literal_true() -> HirScalarExpr {
3100 HirScalarExpr::literal(Datum::True, SqlScalarType::Bool)
3101 }
3102
3103 pub fn literal_false() -> HirScalarExpr {
3104 HirScalarExpr::literal(Datum::False, SqlScalarType::Bool)
3105 }
3106
3107 pub fn literal_null(scalar_type: SqlScalarType) -> HirScalarExpr {
3108 HirScalarExpr::literal(Datum::Null, scalar_type)
3109 }
3110
3111 pub fn literal_1d_array(
3112 datums: Vec<Datum>,
3113 element_scalar_type: SqlScalarType,
3114 ) -> Result<HirScalarExpr, PlanError> {
3115 let scalar_type = match element_scalar_type {
3116 SqlScalarType::Array(_) => {
3117 sql_bail!("cannot build array from array type");
3118 }
3119 typ => SqlScalarType::Array(Box::new(typ)).nullable(false),
3120 };
3121
3122 let mut row = Row::default();
3123 row.packer()
3124 .try_push_array(
3125 &[ArrayDimension {
3126 lower_bound: 1,
3127 length: datums.len(),
3128 }],
3129 datums,
3130 )
3131 .expect("array constructed to be valid");
3132
3133 Ok(HirScalarExpr::Literal(row, scalar_type, TreatAsEqual(None)))
3134 }
3135
3136 pub fn as_literal(&self) -> Option<Datum<'_>> {
3137 if let HirScalarExpr::Literal(row, _column_type, _name) = self {
3138 Some(row.unpack_first())
3139 } else {
3140 None
3141 }
3142 }
3143
3144 pub fn is_literal_true(&self) -> bool {
3145 Some(Datum::True) == self.as_literal()
3146 }
3147
3148 pub fn is_literal_false(&self) -> bool {
3149 Some(Datum::False) == self.as_literal()
3150 }
3151
3152 pub fn is_literal_null(&self) -> bool {
3153 Some(Datum::Null) == self.as_literal()
3154 }
3155
3156 pub fn is_constant(&self) -> bool {
3159 let mut worklist = vec![self];
3160 while let Some(expr) = worklist.pop() {
3161 match expr {
3162 Self::Literal(..) => {
3163 }
3165 Self::CallUnary { expr, .. } => {
3166 worklist.push(expr);
3167 }
3168 Self::CallBinary {
3169 func: _,
3170 expr1,
3171 expr2,
3172 name: _,
3173 } => {
3174 worklist.push(expr1);
3175 worklist.push(expr2);
3176 }
3177 Self::CallVariadic {
3178 func: _,
3179 exprs,
3180 name: _,
3181 } => {
3182 worklist.extend(exprs.iter());
3183 }
3184 Self::If {
3186 cond,
3187 then,
3188 els,
3189 name: _,
3190 } => {
3191 worklist.push(cond);
3192 worklist.push(then);
3193 worklist.push(els);
3194 }
3195 _ => {
3196 return false; }
3198 }
3199 }
3200 true
3201 }
3202
3203 pub fn call_unary(self, func: UnaryFunc) -> Self {
3204 HirScalarExpr::CallUnary {
3205 func,
3206 expr: Box::new(self),
3207 name: NameMetadata::default(),
3208 }
3209 }
3210
3211 pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
3212 HirScalarExpr::CallBinary {
3213 func,
3214 expr1: Box::new(self),
3215 expr2: Box::new(other),
3216 name: NameMetadata::default(),
3217 }
3218 }
3219
3220 pub fn call_unmaterializable(func: UnmaterializableFunc) -> Self {
3221 HirScalarExpr::CallUnmaterializable(func, NameMetadata::default())
3222 }
3223
3224 pub fn call_variadic(func: VariadicFunc, exprs: Vec<Self>) -> Self {
3225 HirScalarExpr::CallVariadic {
3226 func,
3227 exprs,
3228 name: NameMetadata::default(),
3229 }
3230 }
3231
3232 pub fn if_then_else(cond: Self, then: Self, els: Self) -> Self {
3233 HirScalarExpr::If {
3234 cond: Box::new(cond),
3235 then: Box::new(then),
3236 els: Box::new(els),
3237 name: NameMetadata::default(),
3238 }
3239 }
3240
3241 pub fn windowing(expr: WindowExpr) -> Self {
3242 HirScalarExpr::Windowing(expr, TreatAsEqual(None))
3243 }
3244
3245 pub fn or(self, other: Self) -> Self {
3246 HirScalarExpr::call_variadic(VariadicFunc::Or, vec![self, other])
3247 }
3248
3249 pub fn and(self, other: Self) -> Self {
3250 HirScalarExpr::call_variadic(VariadicFunc::And, vec![self, other])
3251 }
3252
3253 pub fn not(self) -> Self {
3254 self.call_unary(UnaryFunc::Not(func::Not))
3255 }
3256
3257 pub fn call_is_null(self) -> Self {
3258 self.call_unary(UnaryFunc::IsNull(func::IsNull))
3259 }
3260
3261 pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3263 match args.len() {
3264 0 => HirScalarExpr::literal_true(), 1 => args.swap_remove(0),
3266 _ => HirScalarExpr::call_variadic(VariadicFunc::And, args),
3267 }
3268 }
3269
3270 pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3272 match args.len() {
3273 0 => HirScalarExpr::literal_false(), 1 => args.swap_remove(0),
3275 _ => HirScalarExpr::call_variadic(VariadicFunc::Or, args),
3276 }
3277 }
3278
3279 pub fn take(&mut self) -> Self {
3280 mem::replace(self, HirScalarExpr::literal_null(SqlScalarType::String))
3281 }
3282
3283 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3284 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3290 where
3291 F: FnMut(usize, &ColumnRef),
3292 {
3293 #[allow(deprecated)]
3294 let _ = self.visit_recursively(depth, &mut |depth: usize,
3295 e: &HirScalarExpr|
3296 -> Result<(), ()> {
3297 if let HirScalarExpr::Column(col, _name) = e {
3298 f(depth, col)
3299 }
3300 Ok(())
3301 });
3302 }
3303
3304 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3305 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3307 where
3308 F: FnMut(usize, &mut ColumnRef),
3309 {
3310 #[allow(deprecated)]
3311 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3312 e: &mut HirScalarExpr|
3313 -> Result<(), ()> {
3314 if let HirScalarExpr::Column(col, _name) = e {
3315 f(depth, col)
3316 }
3317 Ok(())
3318 });
3319 }
3320
3321 pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3327 where
3328 F: FnMut(usize),
3329 {
3330 #[allow(deprecated)]
3331 let _ = self.visit_recursively(0, &mut |depth: usize,
3332 e: &HirScalarExpr|
3333 -> Result<(), ()> {
3334 if let HirScalarExpr::Column(col, _name) = e {
3335 if col.level == depth {
3336 f(col.column)
3337 }
3338 }
3339 Ok(())
3340 });
3341 }
3342
3343 pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3345 where
3346 F: FnMut(&mut usize),
3347 {
3348 #[allow(deprecated)]
3349 let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3350 e: &mut HirScalarExpr|
3351 -> Result<(), ()> {
3352 if let HirScalarExpr::Column(col, _name) = e {
3353 if col.level == depth {
3354 f(&mut col.column)
3355 }
3356 }
3357 Ok(())
3358 });
3359 }
3360
3361 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3362 pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3366 where
3367 F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3368 {
3369 match self {
3370 HirScalarExpr::Literal(..)
3371 | HirScalarExpr::Parameter(..)
3372 | HirScalarExpr::CallUnmaterializable(..)
3373 | HirScalarExpr::Column(..) => (),
3374 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3375 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3376 expr1.visit_recursively(depth, f)?;
3377 expr2.visit_recursively(depth, f)?;
3378 }
3379 HirScalarExpr::CallVariadic { exprs, .. } => {
3380 for expr in exprs {
3381 expr.visit_recursively(depth, f)?;
3382 }
3383 }
3384 HirScalarExpr::If {
3385 cond,
3386 then,
3387 els,
3388 name: _,
3389 } => {
3390 cond.visit_recursively(depth, f)?;
3391 then.visit_recursively(depth, f)?;
3392 els.visit_recursively(depth, f)?;
3393 }
3394 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3395 #[allow(deprecated)]
3396 expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3397 e.visit_recursively(depth, f)
3398 })?;
3399 }
3400 HirScalarExpr::Windowing(expr, _name) => {
3401 expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3402 }
3403 }
3404 f(depth, self)
3405 }
3406
3407 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3408 pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3410 where
3411 F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3412 {
3413 match self {
3414 HirScalarExpr::Literal(..)
3415 | HirScalarExpr::Parameter(..)
3416 | HirScalarExpr::CallUnmaterializable(..)
3417 | HirScalarExpr::Column(..) => (),
3418 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3419 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3420 expr1.visit_recursively_mut(depth, f)?;
3421 expr2.visit_recursively_mut(depth, f)?;
3422 }
3423 HirScalarExpr::CallVariadic { exprs, .. } => {
3424 for expr in exprs {
3425 expr.visit_recursively_mut(depth, f)?;
3426 }
3427 }
3428 HirScalarExpr::If {
3429 cond,
3430 then,
3431 els,
3432 name: _,
3433 } => {
3434 cond.visit_recursively_mut(depth, f)?;
3435 then.visit_recursively_mut(depth, f)?;
3436 els.visit_recursively_mut(depth, f)?;
3437 }
3438 HirScalarExpr::Exists(expr, _name) | HirScalarExpr::Select(expr, _name) => {
3439 #[allow(deprecated)]
3440 expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3441 e.visit_recursively_mut(depth, f)
3442 })?;
3443 }
3444 HirScalarExpr::Windowing(expr, _name) => {
3445 expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3446 }
3447 }
3448 f(depth, self)
3449 }
3450
3451 fn simplify_to_literal(self) -> Option<Row> {
3460 let mut expr = self.lower_uncorrelated().ok()?;
3461 expr.reduce(&[]);
3462 match expr {
3463 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3464 _ => None,
3465 }
3466 }
3467
3468 fn simplify_to_literal_with_result(self) -> Result<Row, PlanError> {
3481 let mut expr = self.lower_uncorrelated().map_err(|err| {
3482 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes())
3483 })?;
3484 expr.reduce(&[]);
3485 match expr {
3486 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Ok(row),
3487 mz_expr::MirScalarExpr::Literal(Err(err), _) => Err(
3488 PlanError::ConstantExpressionSimplificationFailed(err.to_string_with_causes()),
3489 ),
3490 _ => Err(PlanError::ConstantExpressionSimplificationFailed(
3491 "Not a constant".to_string(),
3492 )),
3493 }
3494 }
3495
3496 pub fn into_literal_int64(self) -> Option<i64> {
3505 self.simplify_to_literal().and_then(|row| {
3506 let datum = row.unpack_first();
3507 if datum.is_null() {
3508 None
3509 } else {
3510 Some(datum.unwrap_int64())
3511 }
3512 })
3513 }
3514
3515 pub fn into_literal_string(self) -> Option<String> {
3524 self.simplify_to_literal().and_then(|row| {
3525 let datum = row.unpack_first();
3526 if datum.is_null() {
3527 None
3528 } else {
3529 Some(datum.unwrap_str().to_owned())
3530 }
3531 })
3532 }
3533
3534 pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3546 self.simplify_to_literal().and_then(|row| {
3547 let datum = row.unpack_first();
3548 if datum.is_null() {
3549 None
3550 } else {
3551 Some(datum.unwrap_mz_timestamp())
3552 }
3553 })
3554 }
3555
3556 pub fn try_into_literal_int64(self) -> Result<i64, PlanError> {
3568 if !self.is_constant() {
3574 return Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3575 "Expected a constant expression, got {}",
3576 self
3577 )));
3578 }
3579 self.clone()
3580 .simplify_to_literal_with_result()
3581 .and_then(|row| {
3582 let datum = row.unpack_first();
3583 if datum.is_null() {
3584 Err(PlanError::ConstantExpressionSimplificationFailed(format!(
3585 "Expected an expression that evaluates to a non-null value, got {}",
3586 self
3587 )))
3588 } else {
3589 Ok(datum.unwrap_int64())
3590 }
3591 })
3592 }
3593
3594 pub fn contains_parameters(&self) -> bool {
3595 let mut contains_parameters = false;
3596 #[allow(deprecated)]
3597 let _ = self.visit_recursively(0, &mut |_depth: usize,
3598 expr: &HirScalarExpr|
3599 -> Result<(), ()> {
3600 if let HirScalarExpr::Parameter(..) = expr {
3601 contains_parameters = true;
3602 }
3603 Ok(())
3604 });
3605 contains_parameters
3606 }
3607}
3608
3609impl VisitChildren<Self> for HirScalarExpr {
3610 fn visit_children<F>(&self, mut f: F)
3611 where
3612 F: FnMut(&Self),
3613 {
3614 use HirScalarExpr::*;
3615 match self {
3616 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3617 CallUnary { expr, .. } => f(expr),
3618 CallBinary { expr1, expr2, .. } => {
3619 f(expr1);
3620 f(expr2);
3621 }
3622 CallVariadic { exprs, .. } => {
3623 for expr in exprs {
3624 f(expr);
3625 }
3626 }
3627 If {
3628 cond,
3629 then,
3630 els,
3631 name: _,
3632 } => {
3633 f(cond);
3634 f(then);
3635 f(els);
3636 }
3637 Exists(..) | Select(..) => (),
3638 Windowing(expr, _name) => expr.visit_children(f),
3639 }
3640 }
3641
3642 fn visit_mut_children<F>(&mut self, mut f: F)
3643 where
3644 F: FnMut(&mut Self),
3645 {
3646 use HirScalarExpr::*;
3647 match self {
3648 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3649 CallUnary { expr, .. } => f(expr),
3650 CallBinary { expr1, expr2, .. } => {
3651 f(expr1);
3652 f(expr2);
3653 }
3654 CallVariadic { exprs, .. } => {
3655 for expr in exprs {
3656 f(expr);
3657 }
3658 }
3659 If {
3660 cond,
3661 then,
3662 els,
3663 name: _,
3664 } => {
3665 f(cond);
3666 f(then);
3667 f(els);
3668 }
3669 Exists(..) | Select(..) => (),
3670 Windowing(expr, _name) => expr.visit_mut_children(f),
3671 }
3672 }
3673
3674 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3675 where
3676 F: FnMut(&Self) -> Result<(), E>,
3677 E: From<RecursionLimitError>,
3678 {
3679 use HirScalarExpr::*;
3680 match self {
3681 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3682 CallUnary { expr, .. } => f(expr)?,
3683 CallBinary { expr1, expr2, .. } => {
3684 f(expr1)?;
3685 f(expr2)?;
3686 }
3687 CallVariadic { exprs, .. } => {
3688 for expr in exprs {
3689 f(expr)?;
3690 }
3691 }
3692 If {
3693 cond,
3694 then,
3695 els,
3696 name: _,
3697 } => {
3698 f(cond)?;
3699 f(then)?;
3700 f(els)?;
3701 }
3702 Exists(..) | Select(..) => (),
3703 Windowing(expr, _name) => expr.try_visit_children(f)?,
3704 }
3705 Ok(())
3706 }
3707
3708 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3709 where
3710 F: FnMut(&mut Self) -> Result<(), E>,
3711 E: From<RecursionLimitError>,
3712 {
3713 use HirScalarExpr::*;
3714 match self {
3715 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3716 CallUnary { expr, .. } => f(expr)?,
3717 CallBinary { expr1, expr2, .. } => {
3718 f(expr1)?;
3719 f(expr2)?;
3720 }
3721 CallVariadic { exprs, .. } => {
3722 for expr in exprs {
3723 f(expr)?;
3724 }
3725 }
3726 If {
3727 cond,
3728 then,
3729 els,
3730 name: _,
3731 } => {
3732 f(cond)?;
3733 f(then)?;
3734 f(els)?;
3735 }
3736 Exists(..) | Select(..) => (),
3737 Windowing(expr, _name) => expr.try_visit_mut_children(f)?,
3738 }
3739 Ok(())
3740 }
3741}
3742
3743impl AbstractExpr for HirScalarExpr {
3744 type Type = SqlColumnType;
3745
3746 fn typ(
3747 &self,
3748 outers: &[SqlRelationType],
3749 inner: &SqlRelationType,
3750 params: &BTreeMap<usize, SqlScalarType>,
3751 ) -> Self::Type {
3752 stack::maybe_grow(|| match self {
3753 HirScalarExpr::Column(ColumnRef { level, column }, _name) => {
3754 if *level == 0 {
3755 inner.column_types[*column].clone()
3756 } else {
3757 outers[*level - 1].column_types[*column].clone()
3758 }
3759 }
3760 HirScalarExpr::Parameter(n, _name) => params[n].clone().nullable(true),
3761 HirScalarExpr::Literal(_, typ, _name) => typ.clone(),
3762 HirScalarExpr::CallUnmaterializable(func, _name) => func.output_type(),
3763 HirScalarExpr::CallUnary {
3764 expr,
3765 func,
3766 name: _,
3767 } => func.output_type(expr.typ(outers, inner, params)),
3768 HirScalarExpr::CallBinary {
3769 expr1,
3770 expr2,
3771 func,
3772 name: _,
3773 } => func.output_type(
3774 expr1.typ(outers, inner, params),
3775 expr2.typ(outers, inner, params),
3776 ),
3777 HirScalarExpr::CallVariadic {
3778 exprs,
3779 func,
3780 name: _,
3781 } => func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect()),
3782 HirScalarExpr::If {
3783 cond: _,
3784 then,
3785 els,
3786 name: _,
3787 } => {
3788 let then_type = then.typ(outers, inner, params);
3789 let else_type = els.typ(outers, inner, params);
3790 then_type.union(&else_type).unwrap()
3791 }
3792 HirScalarExpr::Exists(_, _name) => SqlScalarType::Bool.nullable(true),
3793 HirScalarExpr::Select(expr, _name) => {
3794 let mut outers = outers.to_vec();
3795 outers.insert(0, inner.clone());
3796 expr.typ(&outers, params)
3797 .column_types
3798 .into_element()
3799 .nullable(true)
3800 }
3801 HirScalarExpr::Windowing(expr, _name) => expr.func.typ(outers, inner, params),
3802 })
3803 }
3804}
3805
3806impl AggregateExpr {
3807 pub fn typ(
3808 &self,
3809 outers: &[SqlRelationType],
3810 inner: &SqlRelationType,
3811 params: &BTreeMap<usize, SqlScalarType>,
3812 ) -> SqlColumnType {
3813 self.func.output_type(self.expr.typ(outers, inner, params))
3814 }
3815
3816 pub fn is_count_asterisk(&self) -> bool {
3824 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3825 }
3826}