1use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::{Display, Formatter};
16use std::{fmt, mem};
17
18use itertools::Itertools;
19use mz_expr::virtual_syntax::{AlgExcept, Except, IR};
20use mz_expr::visit::{Visit, VisitChildren};
21use mz_expr::{CollectionPlan, Id, LetRecLimit, RowSetFinishing, func};
22use mz_expr::AggregateFunc::{FusedWindowAggregate, WindowAggregate};
24pub use mz_expr::{
25 BinaryFunc, ColumnOrder, TableFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc, WindowFrame,
26};
27use mz_ore::collections::CollectionExt;
28use mz_ore::stack;
29use mz_ore::stack::RecursionLimitError;
30use mz_ore::str::separated;
31use mz_repr::adt::array::ArrayDimension;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::*;
34use serde::{Deserialize, Serialize};
35
36use crate::plan::Params;
37use crate::plan::error::PlanError;
38use crate::plan::query::ExprContext;
39use crate::plan::typeconv::{self, CastContext};
40
41use super::plan_utils::GroupSizeHints;
42
/// Marker type naming the high-level IR; it carries no data and only selects
/// `HirRelationExpr` / `HirScalarExpr` through the `IR` and `AlgExcept` trait
/// impls in this module.
// A field-less marker type, so a `Debug` impl would add nothing useful.
#[allow(missing_debug_implementations)]
pub struct Hir;
45
46impl IR for Hir {
47 type Relation = HirRelationExpr;
48 type Scalar = HirScalarExpr;
49}
50
51impl AlgExcept for Hir {
52 fn except(all: &bool, lhs: Self::Relation, rhs: Self::Relation) -> Self::Relation {
53 if *all {
54 let rhs = rhs.negate();
55 HirRelationExpr::union(lhs, rhs).threshold()
56 } else {
57 let lhs = lhs.distinct();
58 let rhs = rhs.distinct().negate();
59 HirRelationExpr::union(lhs, rhs).threshold()
60 }
61 }
62
63 fn un_except<'a>(expr: &'a Self::Relation) -> Option<Except<'a, Self>> {
64 let mut result = None;
65
66 use HirRelationExpr::*;
67 if let Threshold { input } = expr {
68 if let Union { base: lhs, inputs } = input.as_ref() {
69 if let [rhs] = &inputs[..] {
70 if let Negate { input: rhs } = rhs {
71 match (lhs.as_ref(), rhs.as_ref()) {
72 (Distinct { input: lhs }, Distinct { input: rhs }) => {
73 let all = false;
74 let lhs = lhs.as_ref();
75 let rhs = rhs.as_ref();
76 result = Some(Except { all, lhs, rhs })
77 }
78 (lhs, rhs) => {
79 let all = true;
80 result = Some(Except { all, lhs, rhs })
81 }
82 }
83 }
84 }
85 }
86 }
87
88 result
89 }
90}
91
/// A relation expression in the high-level IR; this is the `Relation` type of
/// [`Hir`].
///
/// NOTE: variant and field order feed the derived `PartialOrd`/`Ord` and the
/// serde encoding, so do not reorder them casually.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirRelationExpr {
    /// Literal `rows`, described by the relation type `typ`.
    Constant {
        rows: Vec<Row>,
        typ: RelationType,
    },
    /// A reference to the collection named by `id`, whose type is `typ`.
    Get {
        id: mz_expr::Id,
        typ: RelationType,
    },
    /// A group of bindings (presumably mutually recursive, per the name) that
    /// are in scope for `body`.
    LetRec {
        /// Optional iteration limit for the bindings — TODO confirm semantics
        /// against `LetRecLimit`.
        limit: Option<LetRecLimit>,
        /// Bindings, presumably `(name, id, value, type)` like the fields of `Let`.
        bindings: Vec<(String, mz_expr::LocalId, HirRelationExpr, RelationType)>,
        /// The expression evaluated with `bindings` in scope.
        body: Box<HirRelationExpr>,
    },
    /// Binds `value` to the local `id` (human-readable `name`) within `body`.
    Let {
        /// Display name of the binding.
        name: String,
        /// Identifier through which `body` refers to `value`.
        id: mz_expr::LocalId,
        /// The bound relation.
        value: Box<HirRelationExpr>,
        /// The expression evaluated with the binding in scope.
        body: Box<HirRelationExpr>,
    },
    /// Keeps the `input` columns listed in `outputs`, in that order.
    Project {
        input: Box<HirRelationExpr>,
        outputs: Vec<usize>,
    },
    /// Appends the values of `scalars` as new columns to `input`.
    Map {
        input: Box<HirRelationExpr>,
        scalars: Vec<HirScalarExpr>,
    },
    /// Invokes the table function `func` on the argument expressions `exprs`.
    CallTable {
        func: TableFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Keeps rows of `input` that satisfy all of `predicates`.
    Filter {
        input: Box<HirRelationExpr>,
        predicates: Vec<HirScalarExpr>,
    },
    /// Joins `left` and `right` on the condition `on`, using join type `kind`.
    Join {
        left: Box<HirRelationExpr>,
        right: Box<HirRelationExpr>,
        on: HirScalarExpr,
        kind: JoinKind,
    },
    /// Groups `input` by the `group_key` columns and computes `aggregates`
    /// per group. `expected_group_size` is an optional sizing hint.
    Reduce {
        input: Box<HirRelationExpr>,
        group_key: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        expected_group_size: Option<u64>,
    },
    /// Removes duplicate rows from `input`. Plain (non-`ALL`) `except` wraps
    /// both of its operands in this.
    Distinct {
        input: Box<HirRelationExpr>,
    },
    /// Per `group_key` group, orders by `order_key`, skips `offset` rows and
    /// keeps up to `limit` rows. A `None` limit presumably means unbounded.
    TopK {
        /// Columns defining the groups.
        input: Box<HirRelationExpr>,
        /// Grouping columns of `input`.
        group_key: Vec<usize>,
        /// Ordering applied within each group.
        order_key: Vec<ColumnOrder>,
        /// Optional row-count limit, expressed as a scalar expression.
        limit: Option<HirScalarExpr>,
        /// Number of leading rows to skip within each group.
        offset: usize,
        /// Optional sizing hint for the groups.
        expected_group_size: Option<u64>,
    },
    /// Negates the multiplicities of `input`; `except` uses this on its
    /// right-hand side.
    Negate {
        input: Box<HirRelationExpr>,
    },
    /// Thresholds `input`. `except` applies this over
    /// `Union(lhs, [Negate(rhs)])`, so it presumably drops rows whose
    /// multiplicity is not positive — confirm.
    Threshold {
        /// The expression being thresholded.
        input: Box<HirRelationExpr>,
    },
    /// The union of `base` and every relation in `inputs`.
    Union {
        base: Box<HirRelationExpr>,
        inputs: Vec<HirRelationExpr>,
    },
}
185
/// A scalar expression in the high-level IR; this is the `Scalar` type of
/// [`Hir`].
///
/// NOTE: variant order feeds the derived `PartialOrd`/`Ord` and serde, so do
/// not reorder variants casually.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum HirScalarExpr {
    /// A reference to a column, possibly of an outer scope (see [`ColumnRef`]).
    Column(ColumnRef),
    /// A query parameter, identified by its index.
    Parameter(usize),
    /// A literal value stored in a `Row`, with its column type.
    Literal(Row, ColumnType),
    /// A call to a function that takes no arguments.
    CallUnmaterializable(UnmaterializableFunc),
    /// A call to a one-argument function.
    CallUnary {
        func: UnaryFunc,
        expr: Box<HirScalarExpr>,
    },
    /// A call to a two-argument function.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<HirScalarExpr>,
        expr2: Box<HirScalarExpr>,
    },
    /// A call to a function taking a list of arguments.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<HirScalarExpr>,
    },
    /// Evaluates to `then` when `cond` holds, and to `els` otherwise.
    If {
        cond: Box<HirScalarExpr>,
        then: Box<HirScalarExpr>,
        els: Box<HirScalarExpr>,
    },
    /// A subquery used as an `EXISTS` test.
    Exists(Box<HirRelationExpr>),
    /// A subquery used as a scalar value.
    Select(Box<HirRelationExpr>),
    /// A window function call (see [`WindowExpr`]).
    Windowing(WindowExpr),
}
228
/// A window function call together with its `PARTITION BY` / `ORDER BY`
/// expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct WindowExpr {
    /// The window function and its own arguments.
    pub func: WindowExprType,
    /// Expressions whose values define the partitions.
    pub partition_by: Vec<HirScalarExpr>,
    /// Expressions used for ordering, presumably within each partition.
    pub order_by: Vec<HirScalarExpr>,
}
247
248impl WindowExpr {
249 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
250 where
251 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
252 {
253 #[allow(deprecated)]
254 self.func.visit_expressions(f)?;
255 for expr in self.partition_by.iter() {
256 f(expr)?;
257 }
258 for expr in self.order_by.iter() {
259 f(expr)?;
260 }
261 Ok(())
262 }
263
264 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
265 where
266 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
267 {
268 #[allow(deprecated)]
269 self.func.visit_expressions_mut(f)?;
270 for expr in self.partition_by.iter_mut() {
271 f(expr)?;
272 }
273 for expr in self.order_by.iter_mut() {
274 f(expr)?;
275 }
276 Ok(())
277 }
278}
279
280impl VisitChildren<HirScalarExpr> for WindowExpr {
281 fn visit_children<F>(&self, mut f: F)
282 where
283 F: FnMut(&HirScalarExpr),
284 {
285 self.func.visit_children(&mut f);
286 for expr in self.partition_by.iter() {
287 f(expr);
288 }
289 for expr in self.order_by.iter() {
290 f(expr);
291 }
292 }
293
294 fn visit_mut_children<F>(&mut self, mut f: F)
295 where
296 F: FnMut(&mut HirScalarExpr),
297 {
298 self.func.visit_mut_children(&mut f);
299 for expr in self.partition_by.iter_mut() {
300 f(expr);
301 }
302 for expr in self.order_by.iter_mut() {
303 f(expr);
304 }
305 }
306
307 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
308 where
309 F: FnMut(&HirScalarExpr) -> Result<(), E>,
310 E: From<RecursionLimitError>,
311 {
312 self.func.try_visit_children(&mut f)?;
313 for expr in self.partition_by.iter() {
314 f(expr)?;
315 }
316 for expr in self.order_by.iter() {
317 f(expr)?;
318 }
319 Ok(())
320 }
321
322 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
323 where
324 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
325 E: From<RecursionLimitError>,
326 {
327 self.func.try_visit_mut_children(&mut f)?;
328 for expr in self.partition_by.iter_mut() {
329 f(expr)?;
330 }
331 for expr in self.order_by.iter_mut() {
332 f(expr)?;
333 }
334 Ok(())
335 }
336}
337
/// The kind of window function called by a [`WindowExpr`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum WindowExprType {
    /// `row_number` / `rank` / `dense_rank`, which have no scalar arguments.
    Scalar(ScalarWindowExpr),
    /// `lag` / `lead` / `first_value` / `last_value`, or a fused group of them.
    Value(ValueWindowExpr),
    /// An aggregate function evaluated over a window frame.
    Aggregate(AggregateWindowExpr),
}
360
361impl WindowExprType {
362 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
363 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
364 where
365 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
366 {
367 #[allow(deprecated)]
368 match self {
369 Self::Scalar(expr) => expr.visit_expressions(f),
370 Self::Value(expr) => expr.visit_expressions(f),
371 Self::Aggregate(expr) => expr.visit_expressions(f),
372 }
373 }
374
375 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
376 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
377 where
378 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
379 {
380 #[allow(deprecated)]
381 match self {
382 Self::Scalar(expr) => expr.visit_expressions_mut(f),
383 Self::Value(expr) => expr.visit_expressions_mut(f),
384 Self::Aggregate(expr) => expr.visit_expressions_mut(f),
385 }
386 }
387
388 fn typ(
389 &self,
390 outers: &[RelationType],
391 inner: &RelationType,
392 params: &BTreeMap<usize, ScalarType>,
393 ) -> ColumnType {
394 match self {
395 Self::Scalar(expr) => expr.typ(outers, inner, params),
396 Self::Value(expr) => expr.typ(outers, inner, params),
397 Self::Aggregate(expr) => expr.typ(outers, inner, params),
398 }
399 }
400}
401
402impl VisitChildren<HirScalarExpr> for WindowExprType {
403 fn visit_children<F>(&self, f: F)
404 where
405 F: FnMut(&HirScalarExpr),
406 {
407 match self {
408 Self::Scalar(_) => (),
409 Self::Value(expr) => expr.visit_children(f),
410 Self::Aggregate(expr) => expr.visit_children(f),
411 }
412 }
413
414 fn visit_mut_children<F>(&mut self, f: F)
415 where
416 F: FnMut(&mut HirScalarExpr),
417 {
418 match self {
419 Self::Scalar(_) => (),
420 Self::Value(expr) => expr.visit_mut_children(f),
421 Self::Aggregate(expr) => expr.visit_mut_children(f),
422 }
423 }
424
425 fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
426 where
427 F: FnMut(&HirScalarExpr) -> Result<(), E>,
428 E: From<RecursionLimitError>,
429 {
430 match self {
431 Self::Scalar(_) => Ok(()),
432 Self::Value(expr) => expr.try_visit_children(f),
433 Self::Aggregate(expr) => expr.try_visit_children(f),
434 }
435 }
436
437 fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
438 where
439 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
440 E: From<RecursionLimitError>,
441 {
442 match self {
443 Self::Scalar(_) => Ok(()),
444 Self::Value(expr) => expr.try_visit_mut_children(f),
445 Self::Aggregate(expr) => expr.try_visit_mut_children(f),
446 }
447 }
448}
449
/// A window function without scalar arguments (`row_number`, `rank`,
/// `dense_rank`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ScalarWindowExpr {
    /// Which function to compute.
    pub func: ScalarWindowFunc,
    /// Ordering that is passed through to the lowered `mz_expr::AggregateFunc`.
    pub order_by: Vec<ColumnOrder>,
}
455
456impl ScalarWindowExpr {
457 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
458 pub fn visit_expressions<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
459 where
460 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
461 {
462 match self.func {
463 ScalarWindowFunc::RowNumber => {}
464 ScalarWindowFunc::Rank => {}
465 ScalarWindowFunc::DenseRank => {}
466 }
467 Ok(())
468 }
469
470 #[deprecated = "Implement `VisitChildren<HirScalarExpr>` if needed."]
471 pub fn visit_expressions_mut<'a, F, E>(&'a self, _f: &mut F) -> Result<(), E>
472 where
473 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
474 {
475 match self.func {
476 ScalarWindowFunc::RowNumber => {}
477 ScalarWindowFunc::Rank => {}
478 ScalarWindowFunc::DenseRank => {}
479 }
480 Ok(())
481 }
482
483 fn typ(
484 &self,
485 _outers: &[RelationType],
486 _inner: &RelationType,
487 _params: &BTreeMap<usize, ScalarType>,
488 ) -> ColumnType {
489 self.func.output_type()
490 }
491
492 pub fn into_expr(self) -> mz_expr::AggregateFunc {
493 match self.func {
494 ScalarWindowFunc::RowNumber => mz_expr::AggregateFunc::RowNumber {
495 order_by: self.order_by,
496 },
497 ScalarWindowFunc::Rank => mz_expr::AggregateFunc::Rank {
498 order_by: self.order_by,
499 },
500 ScalarWindowFunc::DenseRank => mz_expr::AggregateFunc::DenseRank {
501 order_by: self.order_by,
502 },
503 }
504 }
505}
506
/// The window functions that take no scalar arguments. Each one produces a
/// non-null `Int64` (see `ScalarWindowFunc::output_type`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ScalarWindowFunc {
    /// `row_number`
    RowNumber,
    /// `rank`
    Rank,
    /// `dense_rank`
    DenseRank,
}
514
515impl Display for ScalarWindowFunc {
516 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
517 match self {
518 ScalarWindowFunc::RowNumber => write!(f, "row_number"),
519 ScalarWindowFunc::Rank => write!(f, "rank"),
520 ScalarWindowFunc::DenseRank => write!(f, "dense_rank"),
521 }
522 }
523}
524
525impl ScalarWindowFunc {
526 pub fn output_type(&self) -> ColumnType {
527 match self {
528 ScalarWindowFunc::RowNumber => ScalarType::Int64.nullable(false),
529 ScalarWindowFunc::Rank => ScalarType::Int64.nullable(false),
530 ScalarWindowFunc::DenseRank => ScalarType::Int64.nullable(false),
531 }
532 }
533}
534
/// A value window function (`lag`, `lead`, `first_value`, `last_value`, or a
/// fused group of them) applied to `args`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ValueWindowExpr {
    /// Which function to compute.
    pub func: ValueWindowFunc,
    /// The function's input. For `lag`/`lead`, `output_type` expects a record
    /// and takes its result type from the first field. For `fused`, it
    /// expects one record field per fused function.
    pub args: Box<HirScalarExpr>,
    /// Ordering that is passed to the lowered `mz_expr::AggregateFunc`.
    pub order_by: Vec<ColumnOrder>,
    /// Frame that is passed to the `first_value`/`last_value` lowering.
    pub window_frame: WindowFrame,
    /// `IGNORE NULLS` flag that is passed to the `lag`/`lead` lowering.
    pub ignore_nulls: bool,
}
549
550impl Display for ValueWindowFunc {
551 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
552 match self {
553 ValueWindowFunc::Lag => write!(f, "lag"),
554 ValueWindowFunc::Lead => write!(f, "lead"),
555 ValueWindowFunc::FirstValue => write!(f, "first_value"),
556 ValueWindowFunc::LastValue => write!(f, "last_value"),
557 ValueWindowFunc::Fused(funcs) => write!(f, "fused[{}]", separated(", ", funcs)),
558 }
559 }
560}
561
562impl ValueWindowExpr {
563 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
564 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
565 where
566 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
567 {
568 f(&self.args)
569 }
570
571 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
572 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
573 where
574 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
575 {
576 f(&mut self.args)
577 }
578
579 fn typ(
580 &self,
581 outers: &[RelationType],
582 inner: &RelationType,
583 params: &BTreeMap<usize, ScalarType>,
584 ) -> ColumnType {
585 self.func.output_type(self.args.typ(outers, inner, params))
586 }
587
588 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
590 (
591 self.args,
592 self.func
593 .into_expr(self.order_by, self.window_frame, self.ignore_nulls),
594 )
595 }
596}
597
598impl VisitChildren<HirScalarExpr> for ValueWindowExpr {
599 fn visit_children<F>(&self, mut f: F)
600 where
601 F: FnMut(&HirScalarExpr),
602 {
603 f(&self.args)
604 }
605
606 fn visit_mut_children<F>(&mut self, mut f: F)
607 where
608 F: FnMut(&mut HirScalarExpr),
609 {
610 f(&mut self.args)
611 }
612
613 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
614 where
615 F: FnMut(&HirScalarExpr) -> Result<(), E>,
616 E: From<RecursionLimitError>,
617 {
618 f(&self.args)
619 }
620
621 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
622 where
623 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
624 E: From<RecursionLimitError>,
625 {
626 f(&mut self.args)
627 }
628}
629
/// The value window functions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ValueWindowFunc {
    /// `lag`
    Lag,
    /// `lead`
    Lead,
    /// `first_value`
    FirstValue,
    /// `last_value`
    LastValue,
    /// Several value window functions evaluated together. They are lowered to
    /// `mz_expr::AggregateFunc::FusedValueWindowFunc` and produce a record with
    /// one field per function.
    Fused(Vec<ValueWindowFunc>),
}
639
640impl ValueWindowFunc {
641 pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
642 match self {
643 ValueWindowFunc::Lag | ValueWindowFunc::Lead => {
644 input_type.scalar_type.unwrap_record_element_type()[0]
646 .clone()
647 .nullable(true)
648 }
649 ValueWindowFunc::FirstValue | ValueWindowFunc::LastValue => {
650 input_type.scalar_type.nullable(true)
651 }
652 ValueWindowFunc::Fused(funcs) => {
653 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
654 ScalarType::Record {
655 fields: funcs
656 .iter()
657 .zip_eq(input_types)
658 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
659 .collect(),
660 custom_id: None,
661 }
662 .nullable(false)
663 }
664 }
665 }
666
667 pub fn into_expr(
668 self,
669 order_by: Vec<ColumnOrder>,
670 window_frame: WindowFrame,
671 ignore_nulls: bool,
672 ) -> mz_expr::AggregateFunc {
673 match self {
674 ValueWindowFunc::Lag => mz_expr::AggregateFunc::LagLead {
676 order_by,
677 lag_lead: mz_expr::LagLeadType::Lag,
678 ignore_nulls,
679 },
680 ValueWindowFunc::Lead => mz_expr::AggregateFunc::LagLead {
681 order_by,
682 lag_lead: mz_expr::LagLeadType::Lead,
683 ignore_nulls,
684 },
685 ValueWindowFunc::FirstValue => mz_expr::AggregateFunc::FirstValue {
686 order_by,
687 window_frame,
688 },
689 ValueWindowFunc::LastValue => mz_expr::AggregateFunc::LastValue {
690 order_by,
691 window_frame,
692 },
693 ValueWindowFunc::Fused(funcs) => mz_expr::AggregateFunc::FusedValueWindowFunc {
694 funcs: funcs
695 .into_iter()
696 .map(|func| {
697 func.into_expr(order_by.clone(), window_frame.clone(), ignore_nulls)
698 })
699 .collect(),
700 order_by,
701 },
702 }
703 }
704}
705
/// An aggregate function evaluated over a window.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateWindowExpr {
    /// The aggregate and its input. Only `func` and `expr` are used when
    /// lowering; `distinct` is not carried over.
    pub aggregate_expr: AggregateExpr,
    /// Ordering that is passed to the lowered window aggregate.
    pub order_by: Vec<ColumnOrder>,
    /// Frame that is passed to the lowered window aggregate.
    pub window_frame: WindowFrame,
}
712
713impl AggregateWindowExpr {
714 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
715 pub fn visit_expressions<'a, F, E>(&'a self, f: &mut F) -> Result<(), E>
716 where
717 F: FnMut(&'a HirScalarExpr) -> Result<(), E>,
718 {
719 f(&self.aggregate_expr.expr)
720 }
721
722 #[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_mut_children` instead."]
723 pub fn visit_expressions_mut<'a, F, E>(&'a mut self, f: &mut F) -> Result<(), E>
724 where
725 F: FnMut(&'a mut HirScalarExpr) -> Result<(), E>,
726 {
727 f(&mut self.aggregate_expr.expr)
728 }
729
730 fn typ(
731 &self,
732 outers: &[RelationType],
733 inner: &RelationType,
734 params: &BTreeMap<usize, ScalarType>,
735 ) -> ColumnType {
736 self.aggregate_expr
737 .func
738 .output_type(self.aggregate_expr.expr.typ(outers, inner, params))
739 }
740
741 pub fn into_expr(self) -> (Box<HirScalarExpr>, mz_expr::AggregateFunc) {
742 if let AggregateFunc::FusedWindowAgg { funcs } = &self.aggregate_expr.func {
743 (
744 self.aggregate_expr.expr,
745 FusedWindowAggregate {
746 wrapped_aggregates: funcs.iter().map(|f| f.clone().into_expr()).collect(),
747 order_by: self.order_by,
748 window_frame: self.window_frame,
749 },
750 )
751 } else {
752 (
753 self.aggregate_expr.expr,
754 WindowAggregate {
755 wrapped_aggregate: Box::new(self.aggregate_expr.func.into_expr()),
756 order_by: self.order_by,
757 window_frame: self.window_frame,
758 },
759 )
760 }
761 }
762}
763
764impl VisitChildren<HirScalarExpr> for AggregateWindowExpr {
765 fn visit_children<F>(&self, mut f: F)
766 where
767 F: FnMut(&HirScalarExpr),
768 {
769 f(&self.aggregate_expr.expr)
770 }
771
772 fn visit_mut_children<F>(&mut self, mut f: F)
773 where
774 F: FnMut(&mut HirScalarExpr),
775 {
776 f(&mut self.aggregate_expr.expr)
777 }
778
779 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
780 where
781 F: FnMut(&HirScalarExpr) -> Result<(), E>,
782 E: From<RecursionLimitError>,
783 {
784 f(&self.aggregate_expr.expr)
785 }
786
787 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
788 where
789 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
790 E: From<RecursionLimitError>,
791 {
792 f(&mut self.aggregate_expr.expr)
793 }
794}
795
/// A scalar expression that may not yet have a concrete type.
///
/// Apart from `Coerced`, the variants have no type of their own. Their
/// `AbstractExpr::typ` is `Uncoerced`, or a `Record` of field types for
/// `LiteralRecord`, until `typeconv::plan_coerce` resolves them against a
/// target type.
#[derive(Clone, Debug)]
pub enum CoercibleScalarExpr {
    /// An expression that already has a concrete type.
    Coerced(HirScalarExpr),
    /// A query parameter, identified by its index.
    Parameter(usize),
    /// A `NULL` literal.
    LiteralNull,
    /// A string literal whose target type is not yet decided.
    LiteralString(String),
    /// A record (row) literal whose fields may themselves be uncoerced.
    LiteralRecord(Vec<CoercibleScalarExpr>),
}
828
829impl CoercibleScalarExpr {
830 pub fn type_as(self, ecx: &ExprContext, ty: &ScalarType) -> Result<HirScalarExpr, PlanError> {
831 let expr = typeconv::plan_coerce(ecx, self, ty)?;
832 let expr_ty = ecx.scalar_type(&expr);
833 if ty != &expr_ty {
834 sql_bail!(
835 "{} must have type {}, not type {}",
836 ecx.name,
837 ecx.humanize_scalar_type(ty, false),
838 ecx.humanize_scalar_type(&expr_ty, false),
839 );
840 }
841 Ok(expr)
842 }
843
844 pub fn type_as_any(self, ecx: &ExprContext) -> Result<HirScalarExpr, PlanError> {
845 typeconv::plan_coerce(ecx, self, &ScalarType::String)
846 }
847
848 pub fn cast_to(
849 self,
850 ecx: &ExprContext,
851 ccx: CastContext,
852 ty: &ScalarType,
853 ) -> Result<HirScalarExpr, PlanError> {
854 let expr = typeconv::plan_coerce(ecx, self, ty)?;
855 typeconv::plan_cast(ecx, ccx, expr, ty)
856 }
857}
858
/// The column type of a [`CoercibleScalarExpr`], which may be only partially
/// known.
#[derive(Clone, Debug)]
pub enum CoercibleColumnType {
    /// A fully known column type.
    Coerced(ColumnType),
    /// A record whose field types may themselves be uncoerced. It is treated
    /// as non-null.
    Record(Vec<CoercibleColumnType>),
    /// Not yet known. It is treated as nullable.
    Uncoerced,
}
866
867impl CoercibleColumnType {
868 pub fn nullable(&self) -> bool {
870 match self {
871 CoercibleColumnType::Coerced(ct) => ct.nullable,
873
874 CoercibleColumnType::Record(_) => false,
876
877 CoercibleColumnType::Uncoerced => true,
880 }
881 }
882}
883
/// The scalar type of a [`CoercibleScalarExpr`], which may be only partially
/// known.
#[derive(Clone, Debug)]
pub enum CoercibleScalarType {
    /// A fully known scalar type.
    Coerced(ScalarType),
    /// A record whose field column types may be uncoerced.
    Record(Vec<CoercibleColumnType>),
    /// Not yet known.
    Uncoerced,
}
891
892impl CoercibleScalarType {
893 pub fn is_coerced(&self) -> bool {
895 matches!(self, CoercibleScalarType::Coerced(_))
896 }
897
898 pub fn as_coerced(&self) -> Option<&ScalarType> {
900 match self {
901 CoercibleScalarType::Coerced(t) => Some(t),
902 _ => None,
903 }
904 }
905
906 pub fn map_coerced<F>(self, f: F) -> CoercibleScalarType
909 where
910 F: FnOnce(ScalarType) -> ScalarType,
911 {
912 match self {
913 CoercibleScalarType::Coerced(t) => CoercibleScalarType::Coerced(f(t)),
914 _ => self,
915 }
916 }
917
918 pub fn force_coerced_if_record(&mut self) {
925 fn convert(uncoerced_fields: impl Iterator<Item = CoercibleColumnType>) -> ScalarType {
926 let mut fields = vec![];
927 for (i, uf) in uncoerced_fields.enumerate() {
928 let name = ColumnName::from(format!("f{}", i + 1));
929 let ty = match uf {
930 CoercibleColumnType::Coerced(ty) => ty,
931 CoercibleColumnType::Record(mut fields) => {
932 convert(fields.drain(..)).nullable(false)
933 }
934 CoercibleColumnType::Uncoerced => ScalarType::String.nullable(true),
935 };
936 fields.push((name, ty))
937 }
938 ScalarType::Record {
939 fields: fields.into(),
940 custom_id: None,
941 }
942 }
943
944 if let CoercibleScalarType::Record(fields) = self {
945 *self = CoercibleScalarType::Coerced(convert(fields.drain(..)));
946 }
947 }
948}
949
/// An expression whose type can be computed from the types of outer scopes,
/// the type of the current input relation, and the known parameter types.
pub trait AbstractExpr {
    /// The (possibly partially known) column type produced by `typ`.
    type Type: AbstractColumnType;

    /// Computes this expression's type.
    ///
    /// `outers` holds the relation types of enclosing scopes, `inner` is the
    /// type of the relation the expression is evaluated against, and `params`
    /// maps parameter indices to their scalar types.
    fn typ(
        &self,
        outers: &[RelationType],
        inner: &RelationType,
        params: &BTreeMap<usize, ScalarType>,
    ) -> Self::Type;
}
964
965impl AbstractExpr for CoercibleScalarExpr {
966 type Type = CoercibleColumnType;
967
968 fn typ(
969 &self,
970 outers: &[RelationType],
971 inner: &RelationType,
972 params: &BTreeMap<usize, ScalarType>,
973 ) -> Self::Type {
974 match self {
975 CoercibleScalarExpr::Coerced(expr) => {
976 CoercibleColumnType::Coerced(expr.typ(outers, inner, params))
977 }
978 CoercibleScalarExpr::LiteralRecord(scalars) => {
979 let fields = scalars
980 .iter()
981 .map(|s| s.typ(outers, inner, params))
982 .collect();
983 CoercibleColumnType::Record(fields)
984 }
985 _ => CoercibleColumnType::Uncoerced,
986 }
987 }
988}
989
/// A column type from which the scalar-type part can be extracted, for both
/// concrete (`ColumnType`) and partially known (`CoercibleColumnType`) types.
pub trait AbstractColumnType {
    /// The matching scalar-type representation.
    type AbstractScalarType;

    /// Consumes the column type and returns its scalar-type part, discarding
    /// nullability.
    fn scalar_type(self) -> Self::AbstractScalarType;
}
1001
1002impl AbstractColumnType for ColumnType {
1003 type AbstractScalarType = ScalarType;
1004
1005 fn scalar_type(self) -> Self::AbstractScalarType {
1006 self.scalar_type
1007 }
1008}
1009
1010impl AbstractColumnType for CoercibleColumnType {
1011 type AbstractScalarType = CoercibleScalarType;
1012
1013 fn scalar_type(self) -> Self::AbstractScalarType {
1014 match self {
1015 CoercibleColumnType::Coerced(t) => CoercibleScalarType::Coerced(t.scalar_type),
1016 CoercibleColumnType::Record(t) => CoercibleScalarType::Record(t),
1017 CoercibleColumnType::Uncoerced => CoercibleScalarType::Uncoerced,
1018 }
1019 }
1020}
1021
1022impl From<HirScalarExpr> for CoercibleScalarExpr {
1023 fn from(expr: HirScalarExpr) -> CoercibleScalarExpr {
1024 CoercibleScalarExpr::Coerced(expr)
1025 }
1026}
1027
/// A reference to a column of the current relation or of an enclosing scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ColumnRef {
    /// Which scope the column belongs to. Presumably 0 is the innermost scope
    /// and larger values count outward — confirm.
    pub level: usize,
    /// Index of the column within the chosen scope's relation.
    pub column: usize,
}
1049
/// The kind of a [`HirRelationExpr::Join`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum JoinKind {
    Inner,
    LeftOuter,
    RightOuter,
    FullOuter,
}
1057
1058impl fmt::Display for JoinKind {
1059 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1060 write!(
1061 f,
1062 "{}",
1063 match self {
1064 JoinKind::Inner => "Inner",
1065 JoinKind::LeftOuter => "LeftOuter",
1066 JoinKind::RightOuter => "RightOuter",
1067 JoinKind::FullOuter => "FullOuter",
1068 }
1069 )
1070 }
1071}
1072
1073impl JoinKind {
1074 pub fn can_be_correlated(&self) -> bool {
1075 match self {
1076 JoinKind::Inner | JoinKind::LeftOuter => true,
1077 JoinKind::RightOuter | JoinKind::FullOuter => false,
1078 }
1079 }
1080}
1081
/// An aggregate function applied to an input expression.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub func: AggregateFunc,
    /// The expression the aggregate is computed over.
    pub expr: Box<HirScalarExpr>,
    /// Whether the aggregate is `DISTINCT`. Note that
    /// `AggregateWindowExpr::into_expr` does not carry this over.
    pub distinct: bool,
}
1088
/// Aggregate functions in the HIR. With the exception of `FusedWindowAgg`,
/// each lowers one-to-one onto the same-named `mz_expr::AggregateFunc`
/// variant via `AggregateFunc::into_expr`.
///
/// NOTE: variant order feeds the derived `PartialOrd`/`Ord` and serde.
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    // `max` over each supported input type.
    MaxNumeric,
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxMzTimestamp,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxTimestamp,
    MaxTimestampTz,
    MaxInterval,
    MaxTime,
    // `min` over each supported input type.
    MinNumeric,
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinMzTimestamp,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinTimestamp,
    MinTimestampTz,
    MinInterval,
    MinTime,
    // `sum` over each supported input type. Some of these widen the output;
    // see `output_type`.
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,
    SumNumeric,
    Count,
    Any,
    All,
    /// Aggregates values into a `Jsonb`, honoring `order_by`.
    JsonbAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Aggregates key/value pairs into a `Jsonb` object, honoring `order_by`.
    JsonbObjectAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Aggregates into a `Map` whose values have type `value_type`.
    MapAgg {
        order_by: Vec<ColumnOrder>,
        value_type: ScalarType,
    },
    /// Concatenates arrays. Its identity datum is the empty array.
    ArrayConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Concatenates lists. Its identity datum is the empty list.
    ListConcat {
        order_by: Vec<ColumnOrder>,
    },
    /// Concatenates strings, producing a `String`.
    StringAgg {
        order_by: Vec<ColumnOrder>,
    },
    /// Several window aggregates evaluated together. They are lowered only
    /// through `AggregateWindowExpr::into_expr`; `AggregateFunc::into_expr`
    /// and `identity_datum` panic on this variant.
    FusedWindowAgg {
        funcs: Vec<AggregateFunc>,
    },
    /// Placeholder aggregate. Its identity datum is `Datum::Dummy`.
    Dummy,
}
1192
1193impl AggregateFunc {
1194 pub fn into_expr(self) -> mz_expr::AggregateFunc {
1196 match self {
1197 AggregateFunc::MaxNumeric => mz_expr::AggregateFunc::MaxNumeric,
1198 AggregateFunc::MaxInt16 => mz_expr::AggregateFunc::MaxInt16,
1199 AggregateFunc::MaxInt32 => mz_expr::AggregateFunc::MaxInt32,
1200 AggregateFunc::MaxInt64 => mz_expr::AggregateFunc::MaxInt64,
1201 AggregateFunc::MaxUInt16 => mz_expr::AggregateFunc::MaxUInt16,
1202 AggregateFunc::MaxUInt32 => mz_expr::AggregateFunc::MaxUInt32,
1203 AggregateFunc::MaxUInt64 => mz_expr::AggregateFunc::MaxUInt64,
1204 AggregateFunc::MaxMzTimestamp => mz_expr::AggregateFunc::MaxMzTimestamp,
1205 AggregateFunc::MaxFloat32 => mz_expr::AggregateFunc::MaxFloat32,
1206 AggregateFunc::MaxFloat64 => mz_expr::AggregateFunc::MaxFloat64,
1207 AggregateFunc::MaxBool => mz_expr::AggregateFunc::MaxBool,
1208 AggregateFunc::MaxString => mz_expr::AggregateFunc::MaxString,
1209 AggregateFunc::MaxDate => mz_expr::AggregateFunc::MaxDate,
1210 AggregateFunc::MaxTimestamp => mz_expr::AggregateFunc::MaxTimestamp,
1211 AggregateFunc::MaxTimestampTz => mz_expr::AggregateFunc::MaxTimestampTz,
1212 AggregateFunc::MaxInterval => mz_expr::AggregateFunc::MaxInterval,
1213 AggregateFunc::MaxTime => mz_expr::AggregateFunc::MaxTime,
1214 AggregateFunc::MinNumeric => mz_expr::AggregateFunc::MinNumeric,
1215 AggregateFunc::MinInt16 => mz_expr::AggregateFunc::MinInt16,
1216 AggregateFunc::MinInt32 => mz_expr::AggregateFunc::MinInt32,
1217 AggregateFunc::MinInt64 => mz_expr::AggregateFunc::MinInt64,
1218 AggregateFunc::MinUInt16 => mz_expr::AggregateFunc::MinUInt16,
1219 AggregateFunc::MinUInt32 => mz_expr::AggregateFunc::MinUInt32,
1220 AggregateFunc::MinUInt64 => mz_expr::AggregateFunc::MinUInt64,
1221 AggregateFunc::MinMzTimestamp => mz_expr::AggregateFunc::MinMzTimestamp,
1222 AggregateFunc::MinFloat32 => mz_expr::AggregateFunc::MinFloat32,
1223 AggregateFunc::MinFloat64 => mz_expr::AggregateFunc::MinFloat64,
1224 AggregateFunc::MinBool => mz_expr::AggregateFunc::MinBool,
1225 AggregateFunc::MinString => mz_expr::AggregateFunc::MinString,
1226 AggregateFunc::MinDate => mz_expr::AggregateFunc::MinDate,
1227 AggregateFunc::MinTimestamp => mz_expr::AggregateFunc::MinTimestamp,
1228 AggregateFunc::MinTimestampTz => mz_expr::AggregateFunc::MinTimestampTz,
1229 AggregateFunc::MinInterval => mz_expr::AggregateFunc::MinInterval,
1230 AggregateFunc::MinTime => mz_expr::AggregateFunc::MinTime,
1231 AggregateFunc::SumInt16 => mz_expr::AggregateFunc::SumInt16,
1232 AggregateFunc::SumInt32 => mz_expr::AggregateFunc::SumInt32,
1233 AggregateFunc::SumInt64 => mz_expr::AggregateFunc::SumInt64,
1234 AggregateFunc::SumUInt16 => mz_expr::AggregateFunc::SumUInt16,
1235 AggregateFunc::SumUInt32 => mz_expr::AggregateFunc::SumUInt32,
1236 AggregateFunc::SumUInt64 => mz_expr::AggregateFunc::SumUInt64,
1237 AggregateFunc::SumFloat32 => mz_expr::AggregateFunc::SumFloat32,
1238 AggregateFunc::SumFloat64 => mz_expr::AggregateFunc::SumFloat64,
1239 AggregateFunc::SumNumeric => mz_expr::AggregateFunc::SumNumeric,
1240 AggregateFunc::Count => mz_expr::AggregateFunc::Count,
1241 AggregateFunc::Any => mz_expr::AggregateFunc::Any,
1242 AggregateFunc::All => mz_expr::AggregateFunc::All,
1243 AggregateFunc::JsonbAgg { order_by } => mz_expr::AggregateFunc::JsonbAgg { order_by },
1244 AggregateFunc::JsonbObjectAgg { order_by } => {
1245 mz_expr::AggregateFunc::JsonbObjectAgg { order_by }
1246 }
1247 AggregateFunc::MapAgg {
1248 order_by,
1249 value_type,
1250 } => mz_expr::AggregateFunc::MapAgg {
1251 order_by,
1252 value_type,
1253 },
1254 AggregateFunc::ArrayConcat { order_by } => {
1255 mz_expr::AggregateFunc::ArrayConcat { order_by }
1256 }
1257 AggregateFunc::ListConcat { order_by } => {
1258 mz_expr::AggregateFunc::ListConcat { order_by }
1259 }
1260 AggregateFunc::StringAgg { order_by } => mz_expr::AggregateFunc::StringAgg { order_by },
1261 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1264 panic!("into_expr called on FusedWindowAgg")
1265 }
1266 AggregateFunc::Dummy => mz_expr::AggregateFunc::Dummy,
1267 }
1268 }
1269
1270 pub fn identity_datum(&self) -> Datum<'static> {
1277 match self {
1278 AggregateFunc::Any => Datum::False,
1279 AggregateFunc::All => Datum::True,
1280 AggregateFunc::Dummy => Datum::Dummy,
1281 AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
1282 AggregateFunc::ListConcat { .. } => Datum::empty_list(),
1283 AggregateFunc::MaxNumeric
1284 | AggregateFunc::MaxInt16
1285 | AggregateFunc::MaxInt32
1286 | AggregateFunc::MaxInt64
1287 | AggregateFunc::MaxUInt16
1288 | AggregateFunc::MaxUInt32
1289 | AggregateFunc::MaxUInt64
1290 | AggregateFunc::MaxMzTimestamp
1291 | AggregateFunc::MaxFloat32
1292 | AggregateFunc::MaxFloat64
1293 | AggregateFunc::MaxBool
1294 | AggregateFunc::MaxString
1295 | AggregateFunc::MaxDate
1296 | AggregateFunc::MaxTimestamp
1297 | AggregateFunc::MaxTimestampTz
1298 | AggregateFunc::MaxInterval
1299 | AggregateFunc::MaxTime
1300 | AggregateFunc::MinNumeric
1301 | AggregateFunc::MinInt16
1302 | AggregateFunc::MinInt32
1303 | AggregateFunc::MinInt64
1304 | AggregateFunc::MinUInt16
1305 | AggregateFunc::MinUInt32
1306 | AggregateFunc::MinUInt64
1307 | AggregateFunc::MinMzTimestamp
1308 | AggregateFunc::MinFloat32
1309 | AggregateFunc::MinFloat64
1310 | AggregateFunc::MinBool
1311 | AggregateFunc::MinString
1312 | AggregateFunc::MinDate
1313 | AggregateFunc::MinTimestamp
1314 | AggregateFunc::MinTimestampTz
1315 | AggregateFunc::MinInterval
1316 | AggregateFunc::MinTime
1317 | AggregateFunc::SumInt16
1318 | AggregateFunc::SumInt32
1319 | AggregateFunc::SumInt64
1320 | AggregateFunc::SumUInt16
1321 | AggregateFunc::SumUInt32
1322 | AggregateFunc::SumUInt64
1323 | AggregateFunc::SumFloat32
1324 | AggregateFunc::SumFloat64
1325 | AggregateFunc::SumNumeric
1326 | AggregateFunc::Count
1327 | AggregateFunc::JsonbAgg { .. }
1328 | AggregateFunc::JsonbObjectAgg { .. }
1329 | AggregateFunc::MapAgg { .. }
1330 | AggregateFunc::StringAgg { .. } => Datum::Null,
1331 AggregateFunc::FusedWindowAgg { funcs: _ } => {
1332 panic!("FusedWindowAgg doesn't have an identity_datum")
1342 }
1343 }
1344 }
1345
1346 pub fn output_type(&self, input_type: ColumnType) -> ColumnType {
1352 let scalar_type = match self {
1353 AggregateFunc::Count => ScalarType::Int64,
1354 AggregateFunc::Any => ScalarType::Bool,
1355 AggregateFunc::All => ScalarType::Bool,
1356 AggregateFunc::JsonbAgg { .. } => ScalarType::Jsonb,
1357 AggregateFunc::JsonbObjectAgg { .. } => ScalarType::Jsonb,
1358 AggregateFunc::StringAgg { .. } => ScalarType::String,
1359 AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => ScalarType::Int64,
1360 AggregateFunc::SumInt64 => ScalarType::Numeric {
1361 max_scale: Some(NumericMaxScale::ZERO),
1362 },
1363 AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 => ScalarType::UInt64,
1364 AggregateFunc::SumUInt64 => ScalarType::Numeric {
1365 max_scale: Some(NumericMaxScale::ZERO),
1366 },
1367 AggregateFunc::MapAgg { value_type, .. } => ScalarType::Map {
1368 value_type: Box::new(value_type.clone()),
1369 custom_id: None,
1370 },
1371 AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
1372 match input_type.scalar_type {
1373 ScalarType::Record { fields, .. } => fields[0].1.scalar_type.clone(),
1375 _ => unreachable!(),
1376 }
1377 }
1378 AggregateFunc::MaxNumeric
1379 | AggregateFunc::MaxInt16
1380 | AggregateFunc::MaxInt32
1381 | AggregateFunc::MaxInt64
1382 | AggregateFunc::MaxUInt16
1383 | AggregateFunc::MaxUInt32
1384 | AggregateFunc::MaxUInt64
1385 | AggregateFunc::MaxMzTimestamp
1386 | AggregateFunc::MaxFloat32
1387 | AggregateFunc::MaxFloat64
1388 | AggregateFunc::MaxBool
1389 | AggregateFunc::MaxString
1390 | AggregateFunc::MaxDate
1391 | AggregateFunc::MaxTimestamp
1392 | AggregateFunc::MaxTimestampTz
1393 | AggregateFunc::MaxInterval
1394 | AggregateFunc::MaxTime
1395 | AggregateFunc::MinNumeric
1396 | AggregateFunc::MinInt16
1397 | AggregateFunc::MinInt32
1398 | AggregateFunc::MinInt64
1399 | AggregateFunc::MinUInt16
1400 | AggregateFunc::MinUInt32
1401 | AggregateFunc::MinUInt64
1402 | AggregateFunc::MinMzTimestamp
1403 | AggregateFunc::MinFloat32
1404 | AggregateFunc::MinFloat64
1405 | AggregateFunc::MinBool
1406 | AggregateFunc::MinString
1407 | AggregateFunc::MinDate
1408 | AggregateFunc::MinTimestamp
1409 | AggregateFunc::MinTimestampTz
1410 | AggregateFunc::MinInterval
1411 | AggregateFunc::MinTime
1412 | AggregateFunc::SumFloat32
1413 | AggregateFunc::SumFloat64
1414 | AggregateFunc::SumNumeric
1415 | AggregateFunc::Dummy => input_type.scalar_type,
1416 AggregateFunc::FusedWindowAgg { funcs } => {
1417 let input_types = input_type.scalar_type.unwrap_record_element_column_type();
1418 ScalarType::Record {
1419 fields: funcs
1420 .iter()
1421 .zip_eq(input_types)
1422 .map(|(f, t)| (ColumnName::from(""), f.output_type(t.clone())))
1423 .collect(),
1424 custom_id: None,
1425 }
1426 }
1427 };
1428 let nullable = !matches!(self, AggregateFunc::Count);
1430 scalar_type.nullable(nullable)
1431 }
1432
1433 pub fn is_order_sensitive(&self) -> bool {
1434 use AggregateFunc::*;
1435 matches!(
1436 self,
1437 JsonbAgg { .. }
1438 | JsonbObjectAgg { .. }
1439 | MapAgg { .. }
1440 | ArrayConcat { .. }
1441 | ListConcat { .. }
1442 | StringAgg { .. }
1443 )
1444 }
1445}
1446
1447impl HirRelationExpr {
1448 pub fn typ(
1449 &self,
1450 outers: &[RelationType],
1451 params: &BTreeMap<usize, ScalarType>,
1452 ) -> RelationType {
1453 stack::maybe_grow(|| match self {
1454 HirRelationExpr::Constant { typ, .. } => typ.clone(),
1455 HirRelationExpr::Get { typ, .. } => typ.clone(),
1456 HirRelationExpr::Let { body, .. } => body.typ(outers, params),
1457 HirRelationExpr::LetRec { body, .. } => body.typ(outers, params),
1458 HirRelationExpr::Project { input, outputs } => {
1459 let input_typ = input.typ(outers, params);
1460 RelationType::new(
1461 outputs
1462 .iter()
1463 .map(|&i| input_typ.column_types[i].clone())
1464 .collect(),
1465 )
1466 }
1467 HirRelationExpr::Map { input, scalars } => {
1468 let mut typ = input.typ(outers, params);
1469 for scalar in scalars {
1470 typ.column_types.push(scalar.typ(outers, &typ, params));
1471 }
1472 typ
1473 }
1474 HirRelationExpr::CallTable { func, exprs: _ } => func.output_type(),
1475 HirRelationExpr::Filter { input, .. } | HirRelationExpr::TopK { input, .. } => {
1476 input.typ(outers, params)
1477 }
1478 HirRelationExpr::Join {
1479 left, right, kind, ..
1480 } => {
1481 let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
1482 let right_nullable =
1483 matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
1484 let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
1485 let nullable = t.nullable || left_nullable;
1486 t.nullable(nullable)
1487 });
1488 let mut outers = outers.to_vec();
1489 outers.insert(0, RelationType::new(lt.clone().collect()));
1490 let rt = right
1491 .typ(&outers, params)
1492 .column_types
1493 .into_iter()
1494 .map(|t| {
1495 let nullable = t.nullable || right_nullable;
1496 t.nullable(nullable)
1497 });
1498 RelationType::new(lt.chain(rt).collect())
1499 }
1500 HirRelationExpr::Reduce {
1501 input,
1502 group_key,
1503 aggregates,
1504 expected_group_size: _,
1505 } => {
1506 let input_typ = input.typ(outers, params);
1507 let mut column_types = group_key
1508 .iter()
1509 .map(|&i| input_typ.column_types[i].clone())
1510 .collect::<Vec<_>>();
1511 for agg in aggregates {
1512 column_types.push(agg.typ(outers, &input_typ, params));
1513 }
1514 RelationType::new(column_types)
1516 }
1517 HirRelationExpr::Distinct { input }
1519 | HirRelationExpr::Negate { input }
1520 | HirRelationExpr::Threshold { input } => input.typ(outers, params),
1521 HirRelationExpr::Union { base, inputs } => {
1522 let mut base_cols = base.typ(outers, params).column_types;
1523 for input in inputs {
1524 for (base_col, col) in base_cols
1525 .iter_mut()
1526 .zip_eq(input.typ(outers, params).column_types)
1527 {
1528 *base_col = base_col.union(&col).unwrap();
1529 }
1530 }
1531 RelationType::new(base_cols)
1532 }
1533 })
1534 }
1535
1536 pub fn arity(&self) -> usize {
1537 match self {
1538 HirRelationExpr::Constant { typ, .. } => typ.column_types.len(),
1539 HirRelationExpr::Get { typ, .. } => typ.column_types.len(),
1540 HirRelationExpr::Let { body, .. } => body.arity(),
1541 HirRelationExpr::LetRec { body, .. } => body.arity(),
1542 HirRelationExpr::Project { outputs, .. } => outputs.len(),
1543 HirRelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
1544 HirRelationExpr::CallTable { func, .. } => func.output_arity(),
1545 HirRelationExpr::Filter { input, .. }
1546 | HirRelationExpr::TopK { input, .. }
1547 | HirRelationExpr::Distinct { input }
1548 | HirRelationExpr::Negate { input }
1549 | HirRelationExpr::Threshold { input } => input.arity(),
1550 HirRelationExpr::Join { left, right, .. } => left.arity() + right.arity(),
1551 HirRelationExpr::Union { base, .. } => base.arity(),
1552 HirRelationExpr::Reduce {
1553 group_key,
1554 aggregates,
1555 ..
1556 } => group_key.len() + aggregates.len(),
1557 }
1558 }
1559
1560 pub fn as_const(&self) -> Option<(&Vec<Row>, &RelationType)> {
1562 match self {
1563 Self::Constant { rows, typ } => Some((rows, typ)),
1564 _ => None,
1565 }
1566 }
1567
1568 pub fn is_correlated(&self) -> bool {
1571 let mut correlated = false;
1572 #[allow(deprecated)]
1573 self.visit_columns(0, &mut |depth, col| {
1574 if col.level > depth && col.level - depth == 1 {
1575 correlated = true;
1576 }
1577 });
1578 correlated
1579 }
1580
1581 pub fn is_join_identity(&self) -> bool {
1582 match self {
1583 HirRelationExpr::Constant { rows, .. } => rows.len() == 1 && self.arity() == 0,
1584 _ => false,
1585 }
1586 }
1587
1588 pub fn project(self, outputs: Vec<usize>) -> Self {
1589 if outputs.iter().copied().eq(0..self.arity()) {
1590 self
1592 } else {
1593 HirRelationExpr::Project {
1594 input: Box::new(self),
1595 outputs,
1596 }
1597 }
1598 }
1599
1600 pub fn map(mut self, scalars: Vec<HirScalarExpr>) -> Self {
1601 if scalars.is_empty() {
1602 self
1604 } else if let HirRelationExpr::Map {
1605 scalars: old_scalars,
1606 input: _,
1607 } = &mut self
1608 {
1609 old_scalars.extend(scalars);
1611 self
1612 } else {
1613 HirRelationExpr::Map {
1614 input: Box::new(self),
1615 scalars,
1616 }
1617 }
1618 }
1619
1620 pub fn filter(mut self, mut preds: Vec<HirScalarExpr>) -> Self {
1621 if let HirRelationExpr::Filter {
1622 input: _,
1623 predicates,
1624 } = &mut self
1625 {
1626 predicates.extend(preds);
1627 predicates.sort();
1628 predicates.dedup();
1629 self
1630 } else {
1631 preds.sort();
1632 preds.dedup();
1633 HirRelationExpr::Filter {
1634 input: Box::new(self),
1635 predicates: preds,
1636 }
1637 }
1638 }
1639
1640 pub fn reduce(
1641 self,
1642 group_key: Vec<usize>,
1643 aggregates: Vec<AggregateExpr>,
1644 expected_group_size: Option<u64>,
1645 ) -> Self {
1646 HirRelationExpr::Reduce {
1647 input: Box::new(self),
1648 group_key,
1649 aggregates,
1650 expected_group_size,
1651 }
1652 }
1653
1654 pub fn top_k(
1655 self,
1656 group_key: Vec<usize>,
1657 order_key: Vec<ColumnOrder>,
1658 limit: Option<HirScalarExpr>,
1659 offset: usize,
1660 expected_group_size: Option<u64>,
1661 ) -> Self {
1662 HirRelationExpr::TopK {
1663 input: Box::new(self),
1664 group_key,
1665 order_key,
1666 limit,
1667 offset,
1668 expected_group_size,
1669 }
1670 }
1671
1672 pub fn negate(self) -> Self {
1673 if let HirRelationExpr::Negate { input } = self {
1674 *input
1675 } else {
1676 HirRelationExpr::Negate {
1677 input: Box::new(self),
1678 }
1679 }
1680 }
1681
1682 pub fn distinct(self) -> Self {
1683 if let HirRelationExpr::Distinct { .. } = self {
1684 self
1685 } else {
1686 HirRelationExpr::Distinct {
1687 input: Box::new(self),
1688 }
1689 }
1690 }
1691
1692 pub fn threshold(self) -> Self {
1693 if let HirRelationExpr::Threshold { .. } = self {
1694 self
1695 } else {
1696 HirRelationExpr::Threshold {
1697 input: Box::new(self),
1698 }
1699 }
1700 }
1701
1702 pub fn union(self, other: Self) -> Self {
1703 let mut terms = Vec::new();
1704 if let HirRelationExpr::Union { base, inputs } = self {
1705 terms.push(*base);
1706 terms.extend(inputs);
1707 } else {
1708 terms.push(self);
1709 }
1710 if let HirRelationExpr::Union { base, inputs } = other {
1711 terms.push(*base);
1712 terms.extend(inputs);
1713 } else {
1714 terms.push(other);
1715 }
1716 HirRelationExpr::Union {
1717 base: Box::new(terms.remove(0)),
1718 inputs: terms,
1719 }
1720 }
1721
1722 pub fn exists(self) -> HirScalarExpr {
1723 HirScalarExpr::Exists(Box::new(self))
1724 }
1725
1726 pub fn select(self) -> HirScalarExpr {
1727 HirScalarExpr::Select(Box::new(self))
1728 }
1729
1730 pub fn join(
1731 self,
1732 mut right: HirRelationExpr,
1733 on: HirScalarExpr,
1734 kind: JoinKind,
1735 ) -> HirRelationExpr {
1736 if self.is_join_identity() && !right.is_correlated() && on == HirScalarExpr::literal_true()
1737 {
1738 #[allow(deprecated)]
1742 right.visit_columns_mut(0, &mut |depth, col| {
1743 if col.level > depth {
1744 col.level -= 1;
1745 }
1746 });
1747 right
1748 } else if right.is_join_identity() && on == HirScalarExpr::literal_true() {
1749 self
1750 } else {
1751 HirRelationExpr::Join {
1752 left: Box::new(self),
1753 right: Box::new(right),
1754 on,
1755 kind,
1756 }
1757 }
1758 }
1759
1760 pub fn take(&mut self) -> HirRelationExpr {
1761 mem::replace(
1762 self,
1763 HirRelationExpr::constant(vec![], RelationType::new(Vec::new())),
1764 )
1765 }
1766
1767 #[deprecated = "Use `Visit::visit_post`."]
1768 pub fn visit<'a, F>(&'a self, depth: usize, f: &mut F)
1769 where
1770 F: FnMut(&'a Self, usize),
1771 {
1772 #[allow(deprecated)]
1773 let _ = self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1774 depth: usize|
1775 -> Result<(), ()> {
1776 f(e, depth);
1777 Ok(())
1778 });
1779 }
1780
1781 #[deprecated = "Use `Visit::try_visit_post`."]
1782 pub fn visit_fallible<'a, F, E>(&'a self, depth: usize, f: &mut F) -> Result<(), E>
1783 where
1784 F: FnMut(&'a Self, usize) -> Result<(), E>,
1785 {
1786 #[allow(deprecated)]
1787 self.visit1(depth, |e: &HirRelationExpr, depth: usize| {
1788 e.visit_fallible(depth, f)
1789 })?;
1790 f(self, depth)
1791 }
1792
1793 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_children` instead."]
1794 pub fn visit1<'a, F, E>(&'a self, depth: usize, mut f: F) -> Result<(), E>
1795 where
1796 F: FnMut(&'a Self, usize) -> Result<(), E>,
1797 {
1798 match self {
1799 HirRelationExpr::Constant { .. }
1800 | HirRelationExpr::Get { .. }
1801 | HirRelationExpr::CallTable { .. } => (),
1802 HirRelationExpr::Let { body, value, .. } => {
1803 f(value, depth)?;
1804 f(body, depth)?;
1805 }
1806 HirRelationExpr::LetRec {
1807 limit: _,
1808 bindings,
1809 body,
1810 } => {
1811 for (_, _, value, _) in bindings.iter() {
1812 f(value, depth)?;
1813 }
1814 f(body, depth)?;
1815 }
1816 HirRelationExpr::Project { input, .. } => {
1817 f(input, depth)?;
1818 }
1819 HirRelationExpr::Map { input, .. } => {
1820 f(input, depth)?;
1821 }
1822 HirRelationExpr::Filter { input, .. } => {
1823 f(input, depth)?;
1824 }
1825 HirRelationExpr::Join { left, right, .. } => {
1826 f(left, depth)?;
1827 f(right, depth + 1)?;
1828 }
1829 HirRelationExpr::Reduce { input, .. } => {
1830 f(input, depth)?;
1831 }
1832 HirRelationExpr::Distinct { input } => {
1833 f(input, depth)?;
1834 }
1835 HirRelationExpr::TopK { input, .. } => {
1836 f(input, depth)?;
1837 }
1838 HirRelationExpr::Negate { input } => {
1839 f(input, depth)?;
1840 }
1841 HirRelationExpr::Threshold { input } => {
1842 f(input, depth)?;
1843 }
1844 HirRelationExpr::Union { base, inputs } => {
1845 f(base, depth)?;
1846 for input in inputs {
1847 f(input, depth)?;
1848 }
1849 }
1850 }
1851 Ok(())
1852 }
1853
1854 #[deprecated = "Use `Visit::visit_mut_post` instead."]
1855 pub fn visit_mut<F>(&mut self, depth: usize, f: &mut F)
1856 where
1857 F: FnMut(&mut Self, usize),
1858 {
1859 #[allow(deprecated)]
1860 let _ = self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
1861 depth: usize|
1862 -> Result<(), ()> {
1863 f(e, depth);
1864 Ok(())
1865 });
1866 }
1867
1868 #[deprecated = "Use `Visit::try_visit_mut_post` instead."]
1869 pub fn visit_mut_fallible<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
1870 where
1871 F: FnMut(&mut Self, usize) -> Result<(), E>,
1872 {
1873 #[allow(deprecated)]
1874 self.visit1_mut(depth, |e: &mut HirRelationExpr, depth: usize| {
1875 e.visit_mut_fallible(depth, f)
1876 })?;
1877 f(self, depth)
1878 }
1879
1880 #[deprecated = "Use `VisitChildren<HirRelationExpr>::try_visit_mut_children` instead."]
1881 pub fn visit1_mut<'a, F, E>(&'a mut self, depth: usize, mut f: F) -> Result<(), E>
1882 where
1883 F: FnMut(&'a mut Self, usize) -> Result<(), E>,
1884 {
1885 match self {
1886 HirRelationExpr::Constant { .. }
1887 | HirRelationExpr::Get { .. }
1888 | HirRelationExpr::CallTable { .. } => (),
1889 HirRelationExpr::Let { body, value, .. } => {
1890 f(value, depth)?;
1891 f(body, depth)?;
1892 }
1893 HirRelationExpr::LetRec {
1894 limit: _,
1895 bindings,
1896 body,
1897 } => {
1898 for (_, _, value, _) in bindings.iter_mut() {
1899 f(value, depth)?;
1900 }
1901 f(body, depth)?;
1902 }
1903 HirRelationExpr::Project { input, .. } => {
1904 f(input, depth)?;
1905 }
1906 HirRelationExpr::Map { input, .. } => {
1907 f(input, depth)?;
1908 }
1909 HirRelationExpr::Filter { input, .. } => {
1910 f(input, depth)?;
1911 }
1912 HirRelationExpr::Join { left, right, .. } => {
1913 f(left, depth)?;
1914 f(right, depth + 1)?;
1915 }
1916 HirRelationExpr::Reduce { input, .. } => {
1917 f(input, depth)?;
1918 }
1919 HirRelationExpr::Distinct { input } => {
1920 f(input, depth)?;
1921 }
1922 HirRelationExpr::TopK { input, .. } => {
1923 f(input, depth)?;
1924 }
1925 HirRelationExpr::Negate { input } => {
1926 f(input, depth)?;
1927 }
1928 HirRelationExpr::Threshold { input } => {
1929 f(input, depth)?;
1930 }
1931 HirRelationExpr::Union { base, inputs } => {
1932 f(base, depth)?;
1933 for input in inputs {
1934 f(input, depth)?;
1935 }
1936 }
1937 }
1938 Ok(())
1939 }
1940
1941 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1942 pub fn visit_scalar_expressions<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
1948 where
1949 F: FnMut(&HirScalarExpr, usize) -> Result<(), E>,
1950 {
1951 #[allow(deprecated)]
1952 self.visit_fallible(depth, &mut |e: &HirRelationExpr,
1953 depth: usize|
1954 -> Result<(), E> {
1955 match e {
1956 HirRelationExpr::Join { on, .. } => {
1957 f(on, depth)?;
1958 }
1959 HirRelationExpr::Map { scalars, .. } => {
1960 for scalar in scalars {
1961 f(scalar, depth)?;
1962 }
1963 }
1964 HirRelationExpr::CallTable { exprs, .. } => {
1965 for expr in exprs {
1966 f(expr, depth)?;
1967 }
1968 }
1969 HirRelationExpr::Filter { predicates, .. } => {
1970 for predicate in predicates {
1971 f(predicate, depth)?;
1972 }
1973 }
1974 HirRelationExpr::Reduce { aggregates, .. } => {
1975 for aggregate in aggregates {
1976 f(&aggregate.expr, depth)?;
1977 }
1978 }
1979 HirRelationExpr::TopK { limit, .. } => {
1980 if let Some(limit) = limit {
1981 f(limit, depth)?;
1982 }
1983 }
1984 HirRelationExpr::Union { .. }
1985 | HirRelationExpr::Let { .. }
1986 | HirRelationExpr::LetRec { .. }
1987 | HirRelationExpr::Project { .. }
1988 | HirRelationExpr::Distinct { .. }
1989 | HirRelationExpr::Negate { .. }
1990 | HirRelationExpr::Threshold { .. }
1991 | HirRelationExpr::Constant { .. }
1992 | HirRelationExpr::Get { .. } => (),
1993 }
1994 Ok(())
1995 })
1996 }
1997
1998 #[deprecated = "Use a combination of `Visit` and `VisitChildren` methods."]
1999 pub fn visit_scalar_expressions_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
2001 where
2002 F: FnMut(&mut HirScalarExpr, usize) -> Result<(), E>,
2003 {
2004 #[allow(deprecated)]
2005 self.visit_mut_fallible(depth, &mut |e: &mut HirRelationExpr,
2006 depth: usize|
2007 -> Result<(), E> {
2008 match e {
2009 HirRelationExpr::Join { on, .. } => {
2010 f(on, depth)?;
2011 }
2012 HirRelationExpr::Map { scalars, .. } => {
2013 for scalar in scalars.iter_mut() {
2014 f(scalar, depth)?;
2015 }
2016 }
2017 HirRelationExpr::CallTable { exprs, .. } => {
2018 for expr in exprs.iter_mut() {
2019 f(expr, depth)?;
2020 }
2021 }
2022 HirRelationExpr::Filter { predicates, .. } => {
2023 for predicate in predicates.iter_mut() {
2024 f(predicate, depth)?;
2025 }
2026 }
2027 HirRelationExpr::Reduce { aggregates, .. } => {
2028 for aggregate in aggregates.iter_mut() {
2029 f(&mut aggregate.expr, depth)?;
2030 }
2031 }
2032 HirRelationExpr::TopK { limit, .. } => {
2033 if let Some(limit) = limit {
2034 f(limit, depth)?;
2035 }
2036 }
2037 HirRelationExpr::Union { .. }
2038 | HirRelationExpr::Let { .. }
2039 | HirRelationExpr::LetRec { .. }
2040 | HirRelationExpr::Project { .. }
2041 | HirRelationExpr::Distinct { .. }
2042 | HirRelationExpr::Negate { .. }
2043 | HirRelationExpr::Threshold { .. }
2044 | HirRelationExpr::Constant { .. }
2045 | HirRelationExpr::Get { .. } => (),
2046 }
2047 Ok(())
2048 })
2049 }
2050
2051 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2052 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
2058 where
2059 F: FnMut(usize, &ColumnRef),
2060 {
2061 #[allow(deprecated)]
2062 let _ = self.visit_scalar_expressions(depth, &mut |e: &HirScalarExpr,
2063 depth: usize|
2064 -> Result<(), ()> {
2065 e.visit_columns(depth, f);
2066 Ok(())
2067 });
2068 }
2069
2070 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
2071 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
2073 where
2074 F: FnMut(usize, &mut ColumnRef),
2075 {
2076 #[allow(deprecated)]
2077 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2078 depth: usize|
2079 -> Result<(), ()> {
2080 e.visit_columns_mut(depth, f);
2081 Ok(())
2082 });
2083 }
2084
2085 pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2088 #[allow(deprecated)]
2089 self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2090 e.bind_parameters(params)
2091 })
2092 }
2093
2094 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2096 #[allow(deprecated)]
2097 let _ = self.visit_scalar_expressions_mut(depth, &mut |e: &mut HirScalarExpr,
2098 depth: usize|
2099 -> Result<(), ()> {
2100 e.splice_parameters(params, depth);
2101 Ok(())
2102 });
2103 }
2104
2105 pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
2107 let rows = rows
2108 .into_iter()
2109 .map(move |datums| Row::pack_slice(&datums))
2110 .collect();
2111 HirRelationExpr::Constant { rows, typ }
2112 }
2113
2114 pub fn finish_maintained(
2120 &mut self,
2121 finishing: &mut RowSetFinishing<HirScalarExpr>,
2122 group_size_hints: GroupSizeHints,
2123 ) {
2124 if !finishing.is_trivial(self.arity()) {
2125 let old_finishing =
2126 mem::replace(finishing, RowSetFinishing::trivial(finishing.project.len()));
2127 *self = HirRelationExpr::top_k(
2128 std::mem::replace(
2129 self,
2130 HirRelationExpr::Constant {
2131 rows: vec![],
2132 typ: RelationType::new(Vec::new()),
2133 },
2134 ),
2135 vec![],
2136 old_finishing.order_by,
2137 old_finishing.limit,
2138 old_finishing.offset,
2139 group_size_hints.limit_input_group_size,
2140 )
2141 .project(old_finishing.project);
2142 }
2143 }
2144
2145 pub fn could_run_expensive_function(&self) -> bool {
2154 let mut result = false;
2155 if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| {
2156 use HirRelationExpr::*;
2157 use HirScalarExpr::*;
2158
2159 self.visit_children(|scalar: &HirScalarExpr| {
2160 if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| {
2161 result |= match scalar {
2162 Column(_)
2163 | Literal(_, _)
2164 | CallUnmaterializable(_)
2165 | If { .. }
2166 | Parameter(..)
2167 | Select(..)
2168 | Exists(..) => false,
2169 CallUnary { .. }
2171 | CallBinary { .. }
2172 | CallVariadic { .. }
2173 | Windowing(..) => true,
2174 };
2175 }) {
2176 result = true;
2178 }
2179 });
2180
2181 result |= matches!(e, CallTable { .. } | Reduce { .. });
2184 }) {
2185 result = true;
2187 }
2188
2189 result
2190 }
2191
2192 pub fn contains_temporal(&self) -> Result<bool, RecursionLimitError> {
2194 let mut contains = false;
2195 self.visit_post(&mut |expr| {
2196 expr.visit_children(|expr: &HirScalarExpr| {
2197 contains = contains || expr.contains_temporal()
2198 })
2199 })?;
2200 Ok(contains)
2201 }
2202}
2203
2204impl CollectionPlan for HirRelationExpr {
2205 fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2208 if let Self::Get {
2209 id: Id::Global(id), ..
2210 } = self
2211 {
2212 out.insert(*id);
2213 }
2214 self.visit_children(|expr: &HirRelationExpr| expr.depends_on_into(out))
2215 }
2216}
2217
2218impl VisitChildren<Self> for HirRelationExpr {
2219 fn visit_children<F>(&self, mut f: F)
2220 where
2221 F: FnMut(&Self),
2222 {
2223 VisitChildren::visit_children(self, |expr: &HirScalarExpr| {
2227 #[allow(deprecated)]
2228 Visit::visit_post_nolimit(expr, &mut |expr| match expr {
2229 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_ref()),
2230 _ => (),
2231 });
2232 });
2233
2234 use HirRelationExpr::*;
2235 match self {
2236 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2237 Let {
2238 name: _,
2239 id: _,
2240 value,
2241 body,
2242 } => {
2243 f(value);
2244 f(body);
2245 }
2246 LetRec {
2247 limit: _,
2248 bindings,
2249 body,
2250 } => {
2251 for (_, _, value, _) in bindings.iter() {
2252 f(value);
2253 }
2254 f(body);
2255 }
2256 Project { input, outputs: _ } => f(input),
2257 Map { input, scalars: _ } => {
2258 f(input);
2259 }
2260 CallTable { func: _, exprs: _ } => (),
2261 Filter {
2262 input,
2263 predicates: _,
2264 } => {
2265 f(input);
2266 }
2267 Join {
2268 left,
2269 right,
2270 on: _,
2271 kind: _,
2272 } => {
2273 f(left);
2274 f(right);
2275 }
2276 Reduce {
2277 input,
2278 group_key: _,
2279 aggregates: _,
2280 expected_group_size: _,
2281 } => {
2282 f(input);
2283 }
2284 Distinct { input }
2285 | TopK {
2286 input,
2287 group_key: _,
2288 order_key: _,
2289 limit: _,
2290 offset: _,
2291 expected_group_size: _,
2292 }
2293 | Negate { input }
2294 | Threshold { input } => {
2295 f(input);
2296 }
2297 Union { base, inputs } => {
2298 f(base);
2299 for input in inputs {
2300 f(input);
2301 }
2302 }
2303 }
2304 }
2305
2306 fn visit_mut_children<F>(&mut self, mut f: F)
2307 where
2308 F: FnMut(&mut Self),
2309 {
2310 VisitChildren::visit_mut_children(self, |expr: &mut HirScalarExpr| {
2314 #[allow(deprecated)]
2315 Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
2316 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_mut()),
2317 _ => (),
2318 });
2319 });
2320
2321 use HirRelationExpr::*;
2322 match self {
2323 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2324 Let {
2325 name: _,
2326 id: _,
2327 value,
2328 body,
2329 } => {
2330 f(value);
2331 f(body);
2332 }
2333 LetRec {
2334 limit: _,
2335 bindings,
2336 body,
2337 } => {
2338 for (_, _, value, _) in bindings.iter_mut() {
2339 f(value);
2340 }
2341 f(body);
2342 }
2343 Project { input, outputs: _ } => f(input),
2344 Map { input, scalars: _ } => {
2345 f(input);
2346 }
2347 CallTable { func: _, exprs: _ } => (),
2348 Filter {
2349 input,
2350 predicates: _,
2351 } => {
2352 f(input);
2353 }
2354 Join {
2355 left,
2356 right,
2357 on: _,
2358 kind: _,
2359 } => {
2360 f(left);
2361 f(right);
2362 }
2363 Reduce {
2364 input,
2365 group_key: _,
2366 aggregates: _,
2367 expected_group_size: _,
2368 } => {
2369 f(input);
2370 }
2371 Distinct { input }
2372 | TopK {
2373 input,
2374 group_key: _,
2375 order_key: _,
2376 limit: _,
2377 offset: _,
2378 expected_group_size: _,
2379 }
2380 | Negate { input }
2381 | Threshold { input } => {
2382 f(input);
2383 }
2384 Union { base, inputs } => {
2385 f(base);
2386 for input in inputs {
2387 f(input);
2388 }
2389 }
2390 }
2391 }
2392
2393 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2394 where
2395 F: FnMut(&Self) -> Result<(), E>,
2396 E: From<RecursionLimitError>,
2397 {
2398 VisitChildren::try_visit_children(self, |expr: &HirScalarExpr| {
2402 Visit::try_visit_post(expr, &mut |expr| match expr {
2403 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_ref()),
2404 _ => Ok(()),
2405 })
2406 })?;
2407
2408 use HirRelationExpr::*;
2409 match self {
2410 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2411 Let {
2412 name: _,
2413 id: _,
2414 value,
2415 body,
2416 } => {
2417 f(value)?;
2418 f(body)?;
2419 }
2420 LetRec {
2421 limit: _,
2422 bindings,
2423 body,
2424 } => {
2425 for (_, _, value, _) in bindings.iter() {
2426 f(value)?;
2427 }
2428 f(body)?;
2429 }
2430 Project { input, outputs: _ } => f(input)?,
2431 Map { input, scalars: _ } => {
2432 f(input)?;
2433 }
2434 CallTable { func: _, exprs: _ } => (),
2435 Filter {
2436 input,
2437 predicates: _,
2438 } => {
2439 f(input)?;
2440 }
2441 Join {
2442 left,
2443 right,
2444 on: _,
2445 kind: _,
2446 } => {
2447 f(left)?;
2448 f(right)?;
2449 }
2450 Reduce {
2451 input,
2452 group_key: _,
2453 aggregates: _,
2454 expected_group_size: _,
2455 } => {
2456 f(input)?;
2457 }
2458 Distinct { input }
2459 | TopK {
2460 input,
2461 group_key: _,
2462 order_key: _,
2463 limit: _,
2464 offset: _,
2465 expected_group_size: _,
2466 }
2467 | Negate { input }
2468 | Threshold { input } => {
2469 f(input)?;
2470 }
2471 Union { base, inputs } => {
2472 f(base)?;
2473 for input in inputs {
2474 f(input)?;
2475 }
2476 }
2477 }
2478 Ok(())
2479 }
2480
2481 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2482 where
2483 F: FnMut(&mut Self) -> Result<(), E>,
2484 E: From<RecursionLimitError>,
2485 {
2486 VisitChildren::try_visit_mut_children(self, |expr: &mut HirScalarExpr| {
2490 Visit::try_visit_mut_post(expr, &mut |expr| match expr {
2491 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => f(expr.as_mut()),
2492 _ => Ok(()),
2493 })
2494 })?;
2495
2496 use HirRelationExpr::*;
2497 match self {
2498 Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (),
2499 Let {
2500 name: _,
2501 id: _,
2502 value,
2503 body,
2504 } => {
2505 f(value)?;
2506 f(body)?;
2507 }
2508 LetRec {
2509 limit: _,
2510 bindings,
2511 body,
2512 } => {
2513 for (_, _, value, _) in bindings.iter_mut() {
2514 f(value)?;
2515 }
2516 f(body)?;
2517 }
2518 Project { input, outputs: _ } => f(input)?,
2519 Map { input, scalars: _ } => {
2520 f(input)?;
2521 }
2522 CallTable { func: _, exprs: _ } => (),
2523 Filter {
2524 input,
2525 predicates: _,
2526 } => {
2527 f(input)?;
2528 }
2529 Join {
2530 left,
2531 right,
2532 on: _,
2533 kind: _,
2534 } => {
2535 f(left)?;
2536 f(right)?;
2537 }
2538 Reduce {
2539 input,
2540 group_key: _,
2541 aggregates: _,
2542 expected_group_size: _,
2543 } => {
2544 f(input)?;
2545 }
2546 Distinct { input }
2547 | TopK {
2548 input,
2549 group_key: _,
2550 order_key: _,
2551 limit: _,
2552 offset: _,
2553 expected_group_size: _,
2554 }
2555 | Negate { input }
2556 | Threshold { input } => {
2557 f(input)?;
2558 }
2559 Union { base, inputs } => {
2560 f(base)?;
2561 for input in inputs {
2562 f(input)?;
2563 }
2564 }
2565 }
2566 Ok(())
2567 }
2568}
2569
2570impl VisitChildren<HirScalarExpr> for HirRelationExpr {
2571 fn visit_children<F>(&self, mut f: F)
2572 where
2573 F: FnMut(&HirScalarExpr),
2574 {
2575 use HirRelationExpr::*;
2576 match self {
2577 Constant { rows: _, typ: _ }
2578 | Get { id: _, typ: _ }
2579 | Let {
2580 name: _,
2581 id: _,
2582 value: _,
2583 body: _,
2584 }
2585 | LetRec {
2586 limit: _,
2587 bindings: _,
2588 body: _,
2589 }
2590 | Project {
2591 input: _,
2592 outputs: _,
2593 } => (),
2594 Map { input: _, scalars } => {
2595 for scalar in scalars {
2596 f(scalar);
2597 }
2598 }
2599 CallTable { func: _, exprs } => {
2600 for expr in exprs {
2601 f(expr);
2602 }
2603 }
2604 Filter {
2605 input: _,
2606 predicates,
2607 } => {
2608 for predicate in predicates {
2609 f(predicate);
2610 }
2611 }
2612 Join {
2613 left: _,
2614 right: _,
2615 on,
2616 kind: _,
2617 } => f(on),
2618 Reduce {
2619 input: _,
2620 group_key: _,
2621 aggregates,
2622 expected_group_size: _,
2623 } => {
2624 for aggregate in aggregates {
2625 f(aggregate.expr.as_ref());
2626 }
2627 }
2628 TopK {
2629 input: _,
2630 group_key: _,
2631 order_key: _,
2632 limit,
2633 offset: _,
2634 expected_group_size: _,
2635 } => {
2636 if let Some(limit) = limit {
2637 f(limit)
2638 }
2639 }
2640 Distinct { input: _ }
2641 | Negate { input: _ }
2642 | Threshold { input: _ }
2643 | Union { base: _, inputs: _ } => (),
2644 }
2645 }
2646
2647 fn visit_mut_children<F>(&mut self, mut f: F)
2648 where
2649 F: FnMut(&mut HirScalarExpr),
2650 {
2651 use HirRelationExpr::*;
2652 match self {
2653 Constant { rows: _, typ: _ }
2654 | Get { id: _, typ: _ }
2655 | Let {
2656 name: _,
2657 id: _,
2658 value: _,
2659 body: _,
2660 }
2661 | LetRec {
2662 limit: _,
2663 bindings: _,
2664 body: _,
2665 }
2666 | Project {
2667 input: _,
2668 outputs: _,
2669 } => (),
2670 Map { input: _, scalars } => {
2671 for scalar in scalars {
2672 f(scalar);
2673 }
2674 }
2675 CallTable { func: _, exprs } => {
2676 for expr in exprs {
2677 f(expr);
2678 }
2679 }
2680 Filter {
2681 input: _,
2682 predicates,
2683 } => {
2684 for predicate in predicates {
2685 f(predicate);
2686 }
2687 }
2688 Join {
2689 left: _,
2690 right: _,
2691 on,
2692 kind: _,
2693 } => f(on),
2694 Reduce {
2695 input: _,
2696 group_key: _,
2697 aggregates,
2698 expected_group_size: _,
2699 } => {
2700 for aggregate in aggregates {
2701 f(aggregate.expr.as_mut());
2702 }
2703 }
2704 Distinct { input: _ }
2705 | TopK {
2706 input: _,
2707 group_key: _,
2708 order_key: _,
2709 limit: _,
2710 offset: _,
2711 expected_group_size: _,
2712 }
2713 | Negate { input: _ }
2714 | Threshold { input: _ }
2715 | Union { base: _, inputs: _ } => (),
2716 }
2717 }
2718
2719 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2720 where
2721 F: FnMut(&HirScalarExpr) -> Result<(), E>,
2722 E: From<RecursionLimitError>,
2723 {
2724 use HirRelationExpr::*;
2725 match self {
2726 Constant { rows: _, typ: _ }
2727 | Get { id: _, typ: _ }
2728 | Let {
2729 name: _,
2730 id: _,
2731 value: _,
2732 body: _,
2733 }
2734 | LetRec {
2735 limit: _,
2736 bindings: _,
2737 body: _,
2738 }
2739 | Project {
2740 input: _,
2741 outputs: _,
2742 } => (),
2743 Map { input: _, scalars } => {
2744 for scalar in scalars {
2745 f(scalar)?;
2746 }
2747 }
2748 CallTable { func: _, exprs } => {
2749 for expr in exprs {
2750 f(expr)?;
2751 }
2752 }
2753 Filter {
2754 input: _,
2755 predicates,
2756 } => {
2757 for predicate in predicates {
2758 f(predicate)?;
2759 }
2760 }
2761 Join {
2762 left: _,
2763 right: _,
2764 on,
2765 kind: _,
2766 } => f(on)?,
2767 Reduce {
2768 input: _,
2769 group_key: _,
2770 aggregates,
2771 expected_group_size: _,
2772 } => {
2773 for aggregate in aggregates {
2774 f(aggregate.expr.as_ref())?;
2775 }
2776 }
2777 Distinct { input: _ }
2778 | TopK {
2779 input: _,
2780 group_key: _,
2781 order_key: _,
2782 limit: _,
2783 offset: _,
2784 expected_group_size: _,
2785 }
2786 | Negate { input: _ }
2787 | Threshold { input: _ }
2788 | Union { base: _, inputs: _ } => (),
2789 }
2790 Ok(())
2791 }
2792
2793 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2794 where
2795 F: FnMut(&mut HirScalarExpr) -> Result<(), E>,
2796 E: From<RecursionLimitError>,
2797 {
2798 use HirRelationExpr::*;
2799 match self {
2800 Constant { rows: _, typ: _ }
2801 | Get { id: _, typ: _ }
2802 | Let {
2803 name: _,
2804 id: _,
2805 value: _,
2806 body: _,
2807 }
2808 | LetRec {
2809 limit: _,
2810 bindings: _,
2811 body: _,
2812 }
2813 | Project {
2814 input: _,
2815 outputs: _,
2816 } => (),
2817 Map { input: _, scalars } => {
2818 for scalar in scalars {
2819 f(scalar)?;
2820 }
2821 }
2822 CallTable { func: _, exprs } => {
2823 for expr in exprs {
2824 f(expr)?;
2825 }
2826 }
2827 Filter {
2828 input: _,
2829 predicates,
2830 } => {
2831 for predicate in predicates {
2832 f(predicate)?;
2833 }
2834 }
2835 Join {
2836 left: _,
2837 right: _,
2838 on,
2839 kind: _,
2840 } => f(on)?,
2841 Reduce {
2842 input: _,
2843 group_key: _,
2844 aggregates,
2845 expected_group_size: _,
2846 } => {
2847 for aggregate in aggregates {
2848 f(aggregate.expr.as_mut())?;
2849 }
2850 }
2851 Distinct { input: _ }
2852 | TopK {
2853 input: _,
2854 group_key: _,
2855 order_key: _,
2856 limit: _,
2857 offset: _,
2858 expected_group_size: _,
2859 }
2860 | Negate { input: _ }
2861 | Threshold { input: _ }
2862 | Union { base: _, inputs: _ } => (),
2863 }
2864 Ok(())
2865 }
2866}
2867
2868impl HirScalarExpr {
2869 pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2872 #[allow(deprecated)]
2873 self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
2874 if let HirScalarExpr::Parameter(n) = e {
2875 let datum = match params.datums.iter().nth(*n - 1) {
2876 None => sql_bail!("there is no parameter ${}", n),
2877 Some(datum) => datum,
2878 };
2879 let scalar_type = ¶ms.types[*n - 1];
2880 let row = Row::pack([datum]);
2881 let column_type = scalar_type.clone().nullable(datum.is_null());
2882 *e = HirScalarExpr::Literal(row, column_type);
2883 }
2884 Ok(())
2885 })
2886 }
2887
2888 pub fn splice_parameters(&mut self, params: &[HirScalarExpr], depth: usize) {
2899 #[allow(deprecated)]
2900 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
2901 e: &mut HirScalarExpr|
2902 -> Result<(), ()> {
2903 if let HirScalarExpr::Parameter(i) = e {
2904 *e = params[*i - 1].clone();
2905 e.visit_columns_mut(0, &mut |d: usize, col: &mut ColumnRef| {
2908 if col.level >= d {
2909 col.level += depth
2910 }
2911 });
2912 }
2913 Ok(())
2914 });
2915 }
2916
2917 pub fn contains_temporal(&self) -> bool {
2919 let mut contains = false;
2920 #[allow(deprecated)]
2921 self.visit_post_nolimit(&mut |e| {
2922 if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow) = e {
2923 contains = true;
2924 }
2925 });
2926 contains
2927 }
2928
2929 pub fn column(index: usize) -> HirScalarExpr {
2931 HirScalarExpr::Column(ColumnRef {
2932 level: 0,
2933 column: index,
2934 })
2935 }
2936
2937 pub fn literal(datum: Datum, scalar_type: ScalarType) -> HirScalarExpr {
2938 let row = Row::pack([datum]);
2939 HirScalarExpr::Literal(row, scalar_type.nullable(datum.is_null()))
2940 }
2941
2942 pub fn literal_true() -> HirScalarExpr {
2943 HirScalarExpr::literal(Datum::True, ScalarType::Bool)
2944 }
2945
2946 pub fn literal_false() -> HirScalarExpr {
2947 HirScalarExpr::literal(Datum::False, ScalarType::Bool)
2948 }
2949
2950 pub fn literal_null(scalar_type: ScalarType) -> HirScalarExpr {
2951 HirScalarExpr::literal(Datum::Null, scalar_type)
2952 }
2953
2954 pub fn literal_1d_array(
2955 datums: Vec<Datum>,
2956 element_scalar_type: ScalarType,
2957 ) -> Result<HirScalarExpr, PlanError> {
2958 let scalar_type = match element_scalar_type {
2959 ScalarType::Array(_) => {
2960 sql_bail!("cannot build array from array type");
2961 }
2962 typ => ScalarType::Array(Box::new(typ)).nullable(false),
2963 };
2964
2965 let mut row = Row::default();
2966 row.packer()
2967 .try_push_array(
2968 &[ArrayDimension {
2969 lower_bound: 1,
2970 length: datums.len(),
2971 }],
2972 datums,
2973 )
2974 .expect("array constructed to be valid");
2975
2976 Ok(HirScalarExpr::Literal(row, scalar_type))
2977 }
2978
2979 pub fn as_literal(&self) -> Option<Datum> {
2980 if let HirScalarExpr::Literal(row, _column_type) = self {
2981 Some(row.unpack_first())
2982 } else {
2983 None
2984 }
2985 }
2986
2987 pub fn is_literal_true(&self) -> bool {
2988 Some(Datum::True) == self.as_literal()
2989 }
2990
2991 pub fn is_literal_false(&self) -> bool {
2992 Some(Datum::False) == self.as_literal()
2993 }
2994
2995 pub fn is_literal_null(&self) -> bool {
2996 Some(Datum::Null) == self.as_literal()
2997 }
2998
2999 pub fn is_constant(&self) -> bool {
3002 let mut worklist = vec![self];
3003 while let Some(expr) = worklist.pop() {
3004 match expr {
3005 Self::Literal(_, _) => {
3006 }
3008 Self::CallUnary { expr, .. } => {
3009 worklist.push(expr);
3010 }
3011 Self::CallBinary {
3012 func: _,
3013 expr1,
3014 expr2,
3015 } => {
3016 worklist.push(expr1);
3017 worklist.push(expr2);
3018 }
3019 Self::CallVariadic { func: _, exprs } => {
3020 worklist.extend(exprs.iter());
3021 }
3022 Self::If { cond, then, els } => {
3023 worklist.push(cond);
3024 worklist.push(then);
3025 worklist.push(els);
3026 }
3027 _ => {
3028 return false; }
3030 }
3031 }
3032 true
3033 }
3034
3035 pub fn call_unary(self, func: UnaryFunc) -> Self {
3036 HirScalarExpr::CallUnary {
3037 func,
3038 expr: Box::new(self),
3039 }
3040 }
3041
3042 pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
3043 HirScalarExpr::CallBinary {
3044 func,
3045 expr1: Box::new(self),
3046 expr2: Box::new(other),
3047 }
3048 }
3049
3050 pub fn or(self, other: Self) -> Self {
3051 HirScalarExpr::CallVariadic {
3052 func: VariadicFunc::Or,
3053 exprs: vec![self, other],
3054 }
3055 }
3056
3057 pub fn and(self, other: Self) -> Self {
3058 HirScalarExpr::CallVariadic {
3059 func: VariadicFunc::And,
3060 exprs: vec![self, other],
3061 }
3062 }
3063
3064 pub fn not(self) -> Self {
3065 self.call_unary(UnaryFunc::Not(func::Not))
3066 }
3067
3068 pub fn call_is_null(self) -> Self {
3069 self.call_unary(UnaryFunc::IsNull(func::IsNull))
3070 }
3071
3072 pub fn variadic_and(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3074 match args.len() {
3075 0 => HirScalarExpr::literal_true(), 1 => args.swap_remove(0),
3077 _ => HirScalarExpr::CallVariadic {
3078 func: VariadicFunc::And,
3079 exprs: args,
3080 },
3081 }
3082 }
3083
3084 pub fn variadic_or(mut args: Vec<HirScalarExpr>) -> HirScalarExpr {
3086 match args.len() {
3087 0 => HirScalarExpr::literal_false(), 1 => args.swap_remove(0),
3089 _ => HirScalarExpr::CallVariadic {
3090 func: VariadicFunc::Or,
3091 exprs: args,
3092 },
3093 }
3094 }
3095
3096 pub fn take(&mut self) -> Self {
3097 mem::replace(self, HirScalarExpr::literal_null(ScalarType::String))
3098 }
3099
3100 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3101 pub fn visit_columns<F>(&self, depth: usize, f: &mut F)
3107 where
3108 F: FnMut(usize, &ColumnRef),
3109 {
3110 #[allow(deprecated)]
3111 let _ = self.visit_recursively(depth, &mut |depth: usize,
3112 e: &HirScalarExpr|
3113 -> Result<(), ()> {
3114 if let HirScalarExpr::Column(col) = e {
3115 f(depth, col)
3116 }
3117 Ok(())
3118 });
3119 }
3120
3121 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3122 pub fn visit_columns_mut<F>(&mut self, depth: usize, f: &mut F)
3124 where
3125 F: FnMut(usize, &mut ColumnRef),
3126 {
3127 #[allow(deprecated)]
3128 let _ = self.visit_recursively_mut(depth, &mut |depth: usize,
3129 e: &mut HirScalarExpr|
3130 -> Result<(), ()> {
3131 if let HirScalarExpr::Column(col) = e {
3132 f(depth, col)
3133 }
3134 Ok(())
3135 });
3136 }
3137
3138 pub fn visit_columns_referring_to_root_level<F>(&self, f: &mut F)
3144 where
3145 F: FnMut(usize),
3146 {
3147 #[allow(deprecated)]
3148 let _ = self.visit_recursively(0, &mut |depth: usize,
3149 e: &HirScalarExpr|
3150 -> Result<(), ()> {
3151 if let HirScalarExpr::Column(col) = e {
3152 if col.level == depth {
3153 f(col.column)
3154 }
3155 }
3156 Ok(())
3157 });
3158 }
3159
3160 pub fn visit_columns_referring_to_root_level_mut<F>(&mut self, f: &mut F)
3162 where
3163 F: FnMut(&mut usize),
3164 {
3165 #[allow(deprecated)]
3166 let _ = self.visit_recursively_mut(0, &mut |depth: usize,
3167 e: &mut HirScalarExpr|
3168 -> Result<(), ()> {
3169 if let HirScalarExpr::Column(col) = e {
3170 if col.level == depth {
3171 f(&mut col.column)
3172 }
3173 }
3174 Ok(())
3175 });
3176 }
3177
3178 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3179 pub fn visit_recursively<F, E>(&self, depth: usize, f: &mut F) -> Result<(), E>
3183 where
3184 F: FnMut(usize, &HirScalarExpr) -> Result<(), E>,
3185 {
3186 match self {
3187 HirScalarExpr::Literal(_, _)
3188 | HirScalarExpr::Parameter(_)
3189 | HirScalarExpr::CallUnmaterializable(_)
3190 | HirScalarExpr::Column(_) => (),
3191 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively(depth, f)?,
3192 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3193 expr1.visit_recursively(depth, f)?;
3194 expr2.visit_recursively(depth, f)?;
3195 }
3196 HirScalarExpr::CallVariadic { exprs, .. } => {
3197 for expr in exprs {
3198 expr.visit_recursively(depth, f)?;
3199 }
3200 }
3201 HirScalarExpr::If { cond, then, els } => {
3202 cond.visit_recursively(depth, f)?;
3203 then.visit_recursively(depth, f)?;
3204 els.visit_recursively(depth, f)?;
3205 }
3206 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => {
3207 #[allow(deprecated)]
3208 expr.visit_scalar_expressions(depth + 1, &mut |e, depth| {
3209 e.visit_recursively(depth, f)
3210 })?;
3211 }
3212 HirScalarExpr::Windowing(expr) => {
3213 expr.visit_expressions(&mut |e| e.visit_recursively(depth, f))?;
3214 }
3215 }
3216 f(depth, self)
3217 }
3218
3219 #[deprecated = "Redefine this based on the `Visit` and `VisitChildren` methods."]
3220 pub fn visit_recursively_mut<F, E>(&mut self, depth: usize, f: &mut F) -> Result<(), E>
3222 where
3223 F: FnMut(usize, &mut HirScalarExpr) -> Result<(), E>,
3224 {
3225 match self {
3226 HirScalarExpr::Literal(_, _)
3227 | HirScalarExpr::Parameter(_)
3228 | HirScalarExpr::CallUnmaterializable(_)
3229 | HirScalarExpr::Column(_) => (),
3230 HirScalarExpr::CallUnary { expr, .. } => expr.visit_recursively_mut(depth, f)?,
3231 HirScalarExpr::CallBinary { expr1, expr2, .. } => {
3232 expr1.visit_recursively_mut(depth, f)?;
3233 expr2.visit_recursively_mut(depth, f)?;
3234 }
3235 HirScalarExpr::CallVariadic { exprs, .. } => {
3236 for expr in exprs {
3237 expr.visit_recursively_mut(depth, f)?;
3238 }
3239 }
3240 HirScalarExpr::If { cond, then, els } => {
3241 cond.visit_recursively_mut(depth, f)?;
3242 then.visit_recursively_mut(depth, f)?;
3243 els.visit_recursively_mut(depth, f)?;
3244 }
3245 HirScalarExpr::Exists(expr) | HirScalarExpr::Select(expr) => {
3246 #[allow(deprecated)]
3247 expr.visit_scalar_expressions_mut(depth + 1, &mut |e, depth| {
3248 e.visit_recursively_mut(depth, f)
3249 })?;
3250 }
3251 HirScalarExpr::Windowing(expr) => {
3252 expr.visit_expressions_mut(&mut |e| e.visit_recursively_mut(depth, f))?;
3253 }
3254 }
3255 f(depth, self)
3256 }
3257
3258 fn simplify_to_literal(self) -> Option<Row> {
3259 let mut expr = self.lower_uncorrelated().ok()?;
3260 expr.reduce(&[]);
3261 match expr {
3262 mz_expr::MirScalarExpr::Literal(Ok(row), _) => Some(row),
3263 _ => None,
3264 }
3265 }
3266
3267 pub fn into_literal_int64(self) -> Option<i64> {
3276 self.simplify_to_literal().and_then(|row| {
3277 let datum = row.unpack_first();
3278 if datum.is_null() {
3279 None
3280 } else {
3281 Some(datum.unwrap_int64())
3282 }
3283 })
3284 }
3285
3286 pub fn into_literal_string(self) -> Option<String> {
3295 self.simplify_to_literal().and_then(|row| {
3296 let datum = row.unpack_first();
3297 if datum.is_null() {
3298 None
3299 } else {
3300 Some(datum.unwrap_str().to_owned())
3301 }
3302 })
3303 }
3304
3305 pub fn into_literal_mz_timestamp(self) -> Option<Timestamp> {
3317 self.simplify_to_literal().and_then(|row| {
3318 let datum = row.unpack_first();
3319 if datum.is_null() {
3320 None
3321 } else {
3322 Some(datum.unwrap_mz_timestamp())
3323 }
3324 })
3325 }
3326}
3327
3328impl VisitChildren<Self> for HirScalarExpr {
3329 fn visit_children<F>(&self, mut f: F)
3330 where
3331 F: FnMut(&Self),
3332 {
3333 use HirScalarExpr::*;
3334 match self {
3335 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3336 CallUnary { expr, .. } => f(expr),
3337 CallBinary { expr1, expr2, .. } => {
3338 f(expr1);
3339 f(expr2);
3340 }
3341 CallVariadic { exprs, .. } => {
3342 for expr in exprs {
3343 f(expr);
3344 }
3345 }
3346 If { cond, then, els } => {
3347 f(cond);
3348 f(then);
3349 f(els);
3350 }
3351 Exists(..) | Select(..) => (),
3352 Windowing(expr) => expr.visit_children(f),
3353 }
3354 }
3355
3356 fn visit_mut_children<F>(&mut self, mut f: F)
3357 where
3358 F: FnMut(&mut Self),
3359 {
3360 use HirScalarExpr::*;
3361 match self {
3362 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3363 CallUnary { expr, .. } => f(expr),
3364 CallBinary { expr1, expr2, .. } => {
3365 f(expr1);
3366 f(expr2);
3367 }
3368 CallVariadic { exprs, .. } => {
3369 for expr in exprs {
3370 f(expr);
3371 }
3372 }
3373 If { cond, then, els } => {
3374 f(cond);
3375 f(then);
3376 f(els);
3377 }
3378 Exists(..) | Select(..) => (),
3379 Windowing(expr) => expr.visit_mut_children(f),
3380 }
3381 }
3382
3383 fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
3384 where
3385 F: FnMut(&Self) -> Result<(), E>,
3386 E: From<RecursionLimitError>,
3387 {
3388 use HirScalarExpr::*;
3389 match self {
3390 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3391 CallUnary { expr, .. } => f(expr)?,
3392 CallBinary { expr1, expr2, .. } => {
3393 f(expr1)?;
3394 f(expr2)?;
3395 }
3396 CallVariadic { exprs, .. } => {
3397 for expr in exprs {
3398 f(expr)?;
3399 }
3400 }
3401 If { cond, then, els } => {
3402 f(cond)?;
3403 f(then)?;
3404 f(els)?;
3405 }
3406 Exists(..) | Select(..) => (),
3407 Windowing(expr) => expr.try_visit_children(f)?,
3408 }
3409 Ok(())
3410 }
3411
3412 fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
3413 where
3414 F: FnMut(&mut Self) -> Result<(), E>,
3415 E: From<RecursionLimitError>,
3416 {
3417 use HirScalarExpr::*;
3418 match self {
3419 Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => (),
3420 CallUnary { expr, .. } => f(expr)?,
3421 CallBinary { expr1, expr2, .. } => {
3422 f(expr1)?;
3423 f(expr2)?;
3424 }
3425 CallVariadic { exprs, .. } => {
3426 for expr in exprs {
3427 f(expr)?;
3428 }
3429 }
3430 If { cond, then, els } => {
3431 f(cond)?;
3432 f(then)?;
3433 f(els)?;
3434 }
3435 Exists(..) | Select(..) => (),
3436 Windowing(expr) => expr.try_visit_mut_children(f)?,
3437 }
3438 Ok(())
3439 }
3440}
3441
3442impl AbstractExpr for HirScalarExpr {
3443 type Type = ColumnType;
3444
3445 fn typ(
3446 &self,
3447 outers: &[RelationType],
3448 inner: &RelationType,
3449 params: &BTreeMap<usize, ScalarType>,
3450 ) -> Self::Type {
3451 stack::maybe_grow(|| match self {
3452 HirScalarExpr::Column(ColumnRef { level, column }) => {
3453 if *level == 0 {
3454 inner.column_types[*column].clone()
3455 } else {
3456 outers[*level - 1].column_types[*column].clone()
3457 }
3458 }
3459 HirScalarExpr::Parameter(n) => params[n].clone().nullable(true),
3460 HirScalarExpr::Literal(_, typ) => typ.clone(),
3461 HirScalarExpr::CallUnmaterializable(func) => func.output_type(),
3462 HirScalarExpr::CallUnary { expr, func } => {
3463 func.output_type(expr.typ(outers, inner, params))
3464 }
3465 HirScalarExpr::CallBinary { expr1, expr2, func } => func.output_type(
3466 expr1.typ(outers, inner, params),
3467 expr2.typ(outers, inner, params),
3468 ),
3469 HirScalarExpr::CallVariadic { exprs, func } => {
3470 func.output_type(exprs.iter().map(|e| e.typ(outers, inner, params)).collect())
3471 }
3472 HirScalarExpr::If { cond: _, then, els } => {
3473 let then_type = then.typ(outers, inner, params);
3474 let else_type = els.typ(outers, inner, params);
3475 then_type.union(&else_type).unwrap()
3476 }
3477 HirScalarExpr::Exists(_) => ScalarType::Bool.nullable(true),
3478 HirScalarExpr::Select(expr) => {
3479 let mut outers = outers.to_vec();
3480 outers.insert(0, inner.clone());
3481 expr.typ(&outers, params)
3482 .column_types
3483 .into_element()
3484 .nullable(true)
3485 }
3486 HirScalarExpr::Windowing(expr) => expr.func.typ(outers, inner, params),
3487 })
3488 }
3489}
3490
3491impl AggregateExpr {
3492 pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
3495 self.expr.bind_parameters(params)
3496 }
3497
3498 pub fn typ(
3499 &self,
3500 outers: &[RelationType],
3501 inner: &RelationType,
3502 params: &BTreeMap<usize, ScalarType>,
3503 ) -> ColumnType {
3504 self.func.output_type(self.expr.typ(outers, inner, params))
3505 }
3506
3507 pub fn is_count_asterisk(&self) -> bool {
3515 self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3516 }
3517}