1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
23use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
24use mz_repr::{ColumnName, ColumnType, RelationType, Row, ScalarBaseType, ScalarType};
25
/// A typechecking context (see `Context`) behind an `Arc<Mutex<..>>`, so that one context can be
/// handed to several `Typecheck` instances.
pub type SharedContext = Arc<Mutex<Context>>;
31
32pub fn empty_context() -> SharedContext {
34 Arc::new(Mutex::new(BTreeMap::new()))
35}
36
/// The ways in which typechecking a `MirRelationExpr` can fail.
///
/// Every variant except `Recursion` records the relation expression (`source`) in which the
/// problem was detected; it is printed as "In the MIR term" when the error is displayed.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referenced an identifier that has no binding in the context.
    Unbound {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The identifier that had no binding.
        id: Id,
        /// The type the `Get` declared for the identifier.
        typ: RelationType,
    },
    /// A scalar expression referenced a column index that the input does not have.
    NoSuchColumn {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The scalar expression containing the bad reference.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column had a type other than the one required.
    MismatchColumn {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: ColumnType,
        /// The column type that was required.
        expected: ColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ColumnTypeDifference>,
        /// Human-readable explanation of what was being checked.
        message: String,
    },
    /// A whole relation had a type other than the one required.
    MismatchColumns {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ColumnType>,
        /// The column types that were required.
        expected: Vec<ColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<RelationTypeDifference>,
        /// Human-readable explanation of what was being checked.
        message: String,
    },
    /// A constant row had the wrong number of datums, or a datum that is not an instance of
    /// the corresponding column type.
    BadConstantRow {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The offending row.
        got: Row,
        /// The column types the row was checked against.
        expected: Vec<ColumnType>,
    },
    /// A projection listed a column index that the input does not have.
    BadProject {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The full list of projected column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<ColumnType>,
    },
    /// A join equivalence class was rejected.
    BadJoinEquivalence {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The types of the members of the equivalence class.
        got: Vec<ColumnType>,
        /// Human-readable explanation of the problem.
        message: String,
    },
    /// A `TopK` group key referenced a column index that the input does not have.
    BadTopKGroupKey {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The out-of-range column index.
        k: usize,
        /// The column types of the `TopK` input.
        input_type: Vec<ColumnType>,
    },
    /// A `TopK` ordering referenced a column index that the input does not have.
    BadTopKOrdering {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The ordering whose column is out of range.
        order: ColumnOrder,
        /// The column types of the `TopK` input.
        input_type: Vec<ColumnType>,
    },
    /// A `LetRec` had a different number of ids than of values.
    BadLetRecBindings {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
    },
    /// A binding reused an identifier that was already bound.
    Shadowing {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The identifier that was bound twice.
        id: Id,
    },
    /// The recursion limit was hit while walking the expression.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A dummy datum was found while dummies are disallowed.
    DisallowedDummy {
        /// Expression in which the error was detected.
        source: &'a MirRelationExpr,
    },
}
156
157impl<'a> From<RecursionLimitError> for TypeError<'a> {
158 fn from(error: RecursionLimitError) -> Self {
159 TypeError::Recursion { error }
160 }
161}
162
/// Maps each bound identifier (local or global) to the column types of the relation bound to it.
type Context = BTreeMap<Id, Vec<ColumnType>>;
164
/// One reason why a relation type `sub` differs from a relation type `sup`, as reported by
/// `relation_subtype_difference`.
#[derive(Clone, Debug, Hash)]
pub enum RelationTypeDifference {
    /// The two relation types have different numbers of columns.
    Length {
        /// Number of columns in `sub`.
        len_sub: usize,
        /// Number of columns in `sup`.
        len_sup: usize,
    },
    /// The two relation types differ at one particular column.
    Column {
        /// Index of the differing column.
        col: usize,
        /// How the column types differ.
        diff: ColumnTypeDifference,
    },
}
185
/// One reason why a column (or scalar) type `sub` differs from a type `sup`, as reported by
/// `column_subtype_difference` and `scalar_subtype_difference`.
#[derive(Clone, Debug, Hash)]
pub enum ColumnTypeDifference {
    /// The two scalar types have different base types.
    NotSubtype {
        /// The candidate subtype.
        sub: ScalarType,
        /// The candidate supertype.
        sup: ScalarType,
    },
    /// `sub` is nullable while `sup` is not.
    Nullability {
        /// The candidate subtype.
        sub: ColumnType,
        /// The candidate supertype.
        sup: ColumnType,
    },
    /// Two container types (list, map, range, or array) differ in their element type.
    ElementType {
        /// Debug rendering of the container's base type.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<ColumnTypeDifference>,
    },
    /// Record fields that are expected but absent.
    RecordMissingFields {
        /// Names of the absent fields.
        missing: Vec<ColumnName>,
    },
    /// Record fields that are present but whose types differ.
    RecordFields {
        /// Each differing field's name paired with one difference for that field.
        fields: Vec<(ColumnName, ColumnTypeDifference)>,
    },
}
224
225impl RelationTypeDifference {
226 pub fn ignore_nullability(self) -> Option<Self> {
230 use RelationTypeDifference::*;
231
232 match self {
233 Length { .. } => Some(self),
234 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
235 }
236 }
237}
238
239impl ColumnTypeDifference {
240 pub fn ignore_nullability(self) -> Option<Self> {
244 use ColumnTypeDifference::*;
245
246 match self {
247 Nullability { .. } => None,
248 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
249 ElementType { ctor, element_type } => {
250 element_type
251 .ignore_nullability()
252 .map(|element_type| ElementType {
253 ctor,
254 element_type: Box::new(element_type),
255 })
256 }
257 RecordFields { fields } => {
258 let fields = fields
259 .into_iter()
260 .flat_map(|(col, diff)| diff.ignore_nullability().map(|diff| (col, diff)))
261 .collect::<Vec<_>>();
262
263 if fields.is_empty() {
264 None
265 } else {
266 Some(RecordFields { fields })
267 }
268 }
269 }
270 }
271}
272
273pub fn relation_subtype_difference(
277 sub: &[ColumnType],
278 sup: &[ColumnType],
279) -> Vec<RelationTypeDifference> {
280 let mut diffs = Vec::new();
281
282 if sub.len() != sup.len() {
283 diffs.push(RelationTypeDifference::Length {
284 len_sub: sub.len(),
285 len_sup: sup.len(),
286 });
287
288 return diffs;
290 }
291
292 diffs.extend(
293 sub.iter()
294 .zip_eq(sup.iter())
295 .enumerate()
296 .flat_map(|(col, (sub_ty, sup_ty))| {
297 column_subtype_difference(sub_ty, sup_ty)
298 .into_iter()
299 .map(move |diff| RelationTypeDifference::Column { col, diff })
300 }),
301 );
302
303 diffs
304}
305
306pub fn column_subtype_difference(sub: &ColumnType, sup: &ColumnType) -> Vec<ColumnTypeDifference> {
310 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
311
312 if sub.nullable && !sup.nullable {
313 diffs.push(ColumnTypeDifference::Nullability {
314 sub: sub.clone(),
315 sup: sup.clone(),
316 });
317 }
318
319 diffs
320}
321
322pub fn scalar_subtype_difference(sub: &ScalarType, sup: &ScalarType) -> Vec<ColumnTypeDifference> {
326 use ScalarType::*;
327
328 let mut diffs = Vec::new();
329
330 match (sub, sup) {
331 (
332 List {
333 element_type: sub_elt,
334 ..
335 },
336 List {
337 element_type: sup_elt,
338 ..
339 },
340 )
341 | (
342 Map {
343 value_type: sub_elt,
344 ..
345 },
346 Map {
347 value_type: sup_elt,
348 ..
349 },
350 )
351 | (
352 Range {
353 element_type: sub_elt,
354 ..
355 },
356 Range {
357 element_type: sup_elt,
358 ..
359 },
360 )
361 | (Array(sub_elt), Array(sup_elt)) => {
362 let ctor = format!("{:?}", ScalarBaseType::from(sub));
363 diffs.extend(
364 scalar_subtype_difference(sub_elt, sup_elt)
365 .into_iter()
366 .map(|diff| ColumnTypeDifference::ElementType {
367 ctor: ctor.clone(),
368 element_type: Box::new(diff),
369 }),
370 );
371 }
372 (
373 Record {
374 fields: sub_fields, ..
375 },
376 Record {
377 fields: sup_fields, ..
378 },
379 ) => {
380 let sub = sub_fields
381 .iter()
382 .map(|(sub_field, sub_ty)| (sub_field.clone(), sub_ty))
383 .collect::<BTreeMap<_, _>>();
384
385 let mut missing = Vec::new();
386 let mut field_diffs = Vec::new();
387 for (sup_field, sup_ty) in sup_fields {
388 if let Some(sub_ty) = sub.get(sup_field) {
389 let diff = column_subtype_difference(sub_ty, sup_ty);
390
391 if !diff.is_empty() {
392 field_diffs.push((sup_field.clone(), diff));
393 }
394 } else {
395 missing.push(sup_field.clone());
396 }
397 }
398 }
399 (_, _) => {
400 if ScalarBaseType::from(sub) != ScalarBaseType::from(sup) {
402 diffs.push(ColumnTypeDifference::NotSubtype {
403 sub: sub.clone(),
404 sup: sup.clone(),
405 })
406 }
407 }
408 };
409
410 diffs
411}
412
413pub fn is_subtype_of(sub: &[ColumnType], sup: &[ColumnType]) -> bool {
418 if sub.len() != sup.len() {
419 return false;
420 }
421
422 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
423 (!known.nullable || got.nullable) && got.scalar_type.base_eq(&known.scalar_type)
424 })
425}
426
/// A transform that typechecks MIR relation expressions without modifying them.
#[derive(Debug)]
pub struct Typecheck {
    /// Shared context holding the types recorded for global ids.
    ctx: SharedContext,
    /// When set, a non-transient global id that has no recorded type is logged.
    disallow_new_globals: bool,
    /// When set, join equivalence classes get additional nullability checks (whose findings
    /// are only logged).
    strict_join_equivalences: bool,
    /// When set, dummy datums in constants and scalar expressions are type errors.
    disallow_dummy: bool,
    /// Guards against unbounded recursion while walking expressions.
    recursion_guard: RecursionGuard,
}
441
442impl CheckedRecursion for Typecheck {
443 fn recursion_guard(&self) -> &RecursionGuard {
444 &self.recursion_guard
445 }
446}
447
448impl Typecheck {
449 pub fn new(ctx: SharedContext) -> Self {
451 Self {
452 ctx,
453 disallow_new_globals: false,
454 strict_join_equivalences: false,
455 disallow_dummy: false,
456 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
457 }
458 }
459
460 pub fn disallow_new_globals(mut self) -> Self {
464 self.disallow_new_globals = true;
465 self
466 }
467
468 pub fn strict_join_equivalences(mut self) -> Self {
472 self.strict_join_equivalences = true;
473
474 self
475 }
476
477 pub fn disallow_dummy(mut self) -> Self {
479 self.disallow_dummy = true;
480 self
481 }
482
483 pub fn typecheck<'a>(
494 &self,
495 expr: &'a MirRelationExpr,
496 ctx: &Context,
497 ) -> Result<Vec<ColumnType>, TypeError<'a>> {
498 use MirRelationExpr::*;
499
500 self.checked_recur(|tc| match expr {
501 Constant { typ, rows } => {
502 if let Ok(rows) = rows {
503 for (row, _id) in rows {
504 let datums = row.unpack();
505
506 if datums.len() != typ.column_types.len() {
508 return Err(TypeError::BadConstantRow {
509 source: expr,
510 got: row.clone(),
511 expected: typ.column_types.clone(),
512 });
513 }
514
515 if datums
517 .iter()
518 .zip_eq(typ.column_types.iter())
519 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of(ty))
520 {
521 return Err(TypeError::BadConstantRow {
522 source: expr,
523 got: row.clone(),
524 expected: typ.column_types.clone(),
525 });
526 }
527
528 if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
529 return Err(TypeError::DisallowedDummy {
530 source: expr,
531 });
532 }
533 }
534 }
535
536 Ok(typ.column_types.clone())
537 }
538 Get { typ, id, .. } => {
539 if let Id::Global(_global_id) = id {
540 if !ctx.contains_key(id) {
541 return Ok(typ.column_types.clone());
543 }
544 }
545
546 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
547 source: expr,
548 id: id.clone(),
549 typ: typ.clone(),
550 })?;
551
552 let diffs = relation_subtype_difference(&typ.column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
554
555 if !diffs.is_empty() {
556 return Err(TypeError::MismatchColumns {
557 source: expr,
558 got: typ.column_types.clone(),
559 expected: ctx_typ.clone(),
560 diffs,
561 message: "annotation did not match context type".into(),
562 });
563 }
564
565 Ok(typ.column_types.clone())
566 }
567 Project { input, outputs } => {
568 let t_in = tc.typecheck(input, ctx)?;
569
570 for x in outputs {
571 if *x >= t_in.len() {
572 return Err(TypeError::BadProject {
573 source: expr,
574 got: outputs.clone(),
575 input_type: t_in,
576 });
577 }
578 }
579
580 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
581 }
582 Map { input, scalars } => {
583 let mut t_in = tc.typecheck(input, ctx)?;
584
585 for scalar_expr in scalars.iter() {
586 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
587
588 if self.disallow_dummy && scalar_expr.contains_dummy() {
589 return Err(TypeError::DisallowedDummy {
590 source: expr,
591 });
592 }
593 }
594
595 Ok(t_in)
596 }
597 FlatMap { input, func, exprs } => {
598 let mut t_in = tc.typecheck(input, ctx)?;
599
600 let mut t_exprs = Vec::with_capacity(exprs.len());
601 for scalar_expr in exprs {
602 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
603
604 if self.disallow_dummy && scalar_expr.contains_dummy() {
605 return Err(TypeError::DisallowedDummy {
606 source: expr,
607 });
608 }
609 }
610 let t_out = func.output_type().column_types;
613
614 t_in.extend(t_out);
616 Ok(t_in)
617 }
618 Filter { input, predicates } => {
619 let mut t_in = tc.typecheck(input, ctx)?;
620
621 for column in non_nullable_columns(predicates) {
624 t_in[column].nullable = false;
625 }
626
627 for scalar_expr in predicates {
628 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
629
630 if t.scalar_type != ScalarType::Bool {
634 let sub = t.scalar_type.clone();
635
636 return Err(TypeError::MismatchColumn {
637 source: expr,
638 got: t,
639 expected: ColumnType {
640 scalar_type: ScalarType::Bool,
641 nullable: true,
642 },
643 diffs: vec![ColumnTypeDifference::NotSubtype { sub, sup: ScalarType::Bool }],
644 message: "expected boolean condition".into(),
645 });
646 }
647
648 if self.disallow_dummy && scalar_expr.contains_dummy() {
649 return Err(TypeError::DisallowedDummy {
650 source: expr,
651 });
652 }
653 }
654
655 Ok(t_in)
656 }
657 Join {
658 inputs,
659 equivalences,
660 implementation,
661 } => {
662 let mut t_in_global = Vec::new();
663 let mut t_in_local = vec![Vec::new(); inputs.len()];
664
665 for (i, input) in inputs.iter().enumerate() {
666 let input_t = tc.typecheck(input, ctx)?;
667 t_in_global.extend(input_t.clone());
668 t_in_local[i] = input_t;
669 }
670
671 for eq_class in equivalences {
672 let mut t_exprs: Vec<ColumnType> = Vec::with_capacity(eq_class.len());
673
674 let mut all_nullable = true;
675
676 for scalar_expr in eq_class {
677 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
679
680 if !t_expr.nullable {
681 all_nullable = false;
682 }
683
684 if let Some(t_first) = t_exprs.get(0) {
685 let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
686 if !diffs.is_empty() {
687 return Err(TypeError::MismatchColumn {
688 source: expr,
689 got: t_expr,
690 expected: t_first.clone(),
691 diffs,
692 message: "equivalence class members have different scalar types".into(),
693 });
694 }
695
696 if self.strict_join_equivalences {
700 if t_expr.nullable != t_first.nullable {
701 let sub = t_expr.clone();
702 let sup = t_first.clone();
703
704 let err = TypeError::MismatchColumn {
705 source: expr,
706 got: t_expr.clone(),
707 expected: t_first.clone(),
708 diffs: vec![ColumnTypeDifference::Nullability { sub, sup }],
709 message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
710 };
711
712 ::tracing::debug!("{err}");
714 }
715 }
716 }
717
718 if self.disallow_dummy && scalar_expr.contains_dummy() {
719 return Err(TypeError::DisallowedDummy {
720 source: expr,
721 });
722 }
723
724 t_exprs.push(t_expr);
725 }
726
727 if self.strict_join_equivalences && all_nullable {
728 let err = TypeError::BadJoinEquivalence {
729 source: expr,
730 got: t_exprs,
731 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
732 };
733
734 ::tracing::debug!("{err}");
736 }
737 }
738
739 match implementation {
741 JoinImplementation::Differential((start_idx, first_key, _), others) => {
742 if let Some(key) = first_key {
743 for k in key {
744 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
745 }
746 }
747
748 for (idx, key, _) in others {
749 for k in key {
750 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
751 }
752 }
753 }
754 JoinImplementation::DeltaQuery(plans) => {
755 for plan in plans {
756 for (idx, key, _) in plan {
757 for k in key {
758 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
759 }
760 }
761 }
762 }
763 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
764 let typ: Vec<ColumnType> = key
765 .iter()
766 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
767 .collect::<Result<Vec<ColumnType>, TypeError>>()?;
768
769 for row in consts {
770 let datums = row.unpack();
771
772 if datums.len() != typ.len() {
774 return Err(TypeError::BadConstantRow {
775 source: expr,
776 got: row.clone(),
777 expected: typ,
778 });
779 }
780
781 if datums
783 .iter()
784 .zip_eq(typ.iter())
785 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of(ty))
786 {
787 return Err(TypeError::BadConstantRow {
788 source: expr,
789 got: row.clone(),
790 expected: typ,
791 });
792 }
793 }
794 }
795 JoinImplementation::Unimplemented => (),
796 }
797
798 Ok(t_in_global)
799 }
800 Reduce {
801 input,
802 group_key,
803 aggregates,
804 monotonic: _,
805 expected_group_size: _,
806 } => {
807 let t_in = tc.typecheck(input, ctx)?;
808
809 let mut t_out = group_key
810 .iter()
811 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
812 .collect::<Result<Vec<_>, _>>()?;
813
814 if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
815 return Err(TypeError::DisallowedDummy {
816 source: expr,
817 });
818 }
819
820 for agg in aggregates {
821 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
822 }
823
824 Ok(t_out)
825 }
826 TopK {
827 input,
828 group_key,
829 order_key,
830 limit: _,
831 offset: _,
832 monotonic: _,
833 expected_group_size: _,
834 } => {
835 let t_in = tc.typecheck(input, ctx)?;
836
837 for &k in group_key {
838 if k >= t_in.len() {
839 return Err(TypeError::BadTopKGroupKey {
840 source: expr,
841 k,
842 input_type: t_in,
843 });
844 }
845 }
846
847 for order in order_key {
848 if order.column >= t_in.len() {
849 return Err(TypeError::BadTopKOrdering {
850 source: expr,
851 order: order.clone(),
852 input_type: t_in,
853 });
854 }
855 }
856
857 Ok(t_in)
858 }
859 Negate { input } => tc.typecheck(input, ctx),
860 Threshold { input } => tc.typecheck(input, ctx),
861 Union { base, inputs } => {
862 let mut t_base = tc.typecheck(base, ctx)?;
863
864 for input in inputs {
865 let t_input = tc.typecheck(input, ctx)?;
866
867 let len_sub = t_base.len();
868 let len_sup = t_input.len();
869 if len_sub != len_sup {
870 return Err(TypeError::MismatchColumns {
871 source: expr,
872 got: t_base.clone(),
873 expected: t_input,
874 diffs: vec![RelationTypeDifference::Length {
875 len_sub,
876 len_sup,
877 }],
878 message: "union branches have different numbers of columns".into(),
879 });
880 }
881
882 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
883 *base_col =
884 base_col
885 .union(&input_col)
886 .map_err(|e| {
887 let base_col = base_col.clone();
888 let diffs = column_subtype_difference(&base_col, &input_col);
889
890 TypeError::MismatchColumn {
891 source: expr,
892 got: input_col,
893 expected: base_col,
894 diffs,
895 message: format!(
896 "couldn't compute union of column types in union: {e}"
897 ),
898 }
899 })?;
900 }
901 }
902
903 Ok(t_base)
904 }
905 Let { id, value, body } => {
906 let t_value = tc.typecheck(value, ctx)?;
907
908 let binding = Id::Local(*id);
909 if ctx.contains_key(&binding) {
910 return Err(TypeError::Shadowing {
911 source: expr,
912 id: binding,
913 });
914 }
915
916 let mut body_ctx = ctx.clone();
917 body_ctx.insert(Id::Local(*id), t_value);
918
919 tc.typecheck(body, &body_ctx)
920 }
921 LetRec { ids, values, body, limits: _ } => {
922 if ids.len() != values.len() {
923 return Err(TypeError::BadLetRecBindings { source: expr });
924 }
925
926 let mut ctx = ctx.clone();
929 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
931 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
932 }
933
934 for (id, value) in ids.iter().zip_eq(values.iter()) {
935 let typ = tc.typecheck(value, &ctx)?;
936
937 let id = Id::Local(id.clone());
938 if let Some(ctx_typ) = ctx.get_mut(&id) {
939 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
940 *base_col = base_col.union(&input_col).map_err(|e| {
941 let base_col = base_col.clone();
942 let diffs = column_subtype_difference(&base_col, &input_col);
943
944 TypeError::MismatchColumn {
945 source: expr,
946 got: input_col,
947 expected: base_col,
948 diffs,
949 message: format!(
950 "couldn't compute union of column types in let rec: {e}"
951 ),
952 }
953 })?;
954 }
955 } else {
956 ctx.insert(id, typ);
958 }
959 }
960
961 tc.typecheck(body, &ctx)
962 }
963 ArrangeBy { input, keys } => {
964 let t_in = tc.typecheck(input, ctx)?;
965
966 for key in keys {
967 for k in key {
968 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
969 }
970 }
971
972 Ok(t_in)
973 }
974 })
975 }
976
977 fn collect_recursive_variable_types<'a>(
981 &self,
982 expr: &'a MirRelationExpr,
983 ids: &[LocalId],
984 ctx: &mut Context,
985 ) -> Result<(), TypeError<'a>> {
986 use MirRelationExpr::*;
987
988 self.checked_recur(|tc| {
989 match expr {
990 Get {
991 id: Id::Local(id),
992 typ,
993 ..
994 } => {
995 if !ids.contains(id) {
996 return Ok(());
997 }
998
999 let id = Id::Local(id.clone());
1000 if let Some(ctx_typ) = ctx.get_mut(&id) {
1001 for (base_col, input_col) in
1002 ctx_typ.iter_mut().zip_eq(typ.column_types.iter())
1003 {
1004 *base_col = base_col.union(input_col).map_err(|e| {
1005 let base_col = base_col.clone();
1006 let diffs = column_subtype_difference(&base_col, input_col);
1007
1008 TypeError::MismatchColumn {
1009 source: expr,
1010 got: input_col.clone(),
1011 expected: base_col,
1012 diffs,
1013 message: format!(
1014 "couldn't compute union of collected column types: {}",
1015 e
1016 ),
1017 }
1018 })?;
1019 }
1020 } else {
1021 ctx.insert(id, typ.column_types.clone());
1022 }
1023 }
1024 Get {
1025 id: Id::Global(..), ..
1026 }
1027 | Constant { .. } => (),
1028 Let { id, value, body } => {
1029 tc.collect_recursive_variable_types(value, ids, ctx)?;
1030
1031 if ids.contains(id) {
1033 return Err(TypeError::Shadowing {
1034 source: expr,
1035 id: Id::Local(*id),
1036 });
1037 }
1038
1039 tc.collect_recursive_variable_types(body, ids, ctx)?;
1040 }
1041 LetRec {
1042 ids: inner_ids,
1043 values,
1044 body,
1045 limits: _,
1046 } => {
1047 for inner_id in inner_ids {
1048 if ids.contains(inner_id) {
1049 return Err(TypeError::Shadowing {
1050 source: expr,
1051 id: Id::Local(*inner_id),
1052 });
1053 }
1054 }
1055
1056 for value in values {
1057 tc.collect_recursive_variable_types(value, ids, ctx)?;
1058 }
1059
1060 tc.collect_recursive_variable_types(body, ids, ctx)?;
1061 }
1062 Project { input, .. }
1063 | Map { input, .. }
1064 | FlatMap { input, .. }
1065 | Filter { input, .. }
1066 | Reduce { input, .. }
1067 | TopK { input, .. }
1068 | Negate { input }
1069 | Threshold { input }
1070 | ArrangeBy { input, .. } => {
1071 tc.collect_recursive_variable_types(input, ids, ctx)?;
1072 }
1073 Join { inputs, .. } => {
1074 for input in inputs {
1075 tc.collect_recursive_variable_types(input, ids, ctx)?;
1076 }
1077 }
1078 Union { base, inputs } => {
1079 tc.collect_recursive_variable_types(base, ids, ctx)?;
1080
1081 for input in inputs {
1082 tc.collect_recursive_variable_types(input, ids, ctx)?;
1083 }
1084 }
1085 }
1086
1087 Ok(())
1088 })
1089 }
1090
1091 fn typecheck_scalar<'a>(
1092 &self,
1093 expr: &'a MirScalarExpr,
1094 source: &'a MirRelationExpr,
1095 column_types: &[ColumnType],
1096 ) -> Result<ColumnType, TypeError<'a>> {
1097 use MirScalarExpr::*;
1098
1099 self.checked_recur(|tc| match expr {
1100 Column(i) => match column_types.get(*i) {
1101 Some(ty) => Ok(ty.clone()),
1102 None => Err(TypeError::NoSuchColumn {
1103 source,
1104 expr,
1105 col: *i,
1106 }),
1107 },
1108 Literal(row, typ) => {
1109 if let Ok(row) = row {
1110 let datums = row.unpack();
1111
1112 if datums.len() != 1
1113 || (datums[0] != mz_repr::Datum::Dummy && !datums[0].is_instance_of(typ))
1114 {
1115 return Err(TypeError::BadConstantRow {
1116 source,
1117 got: row.clone(),
1118 expected: vec![typ.clone()],
1119 });
1120 }
1121 }
1122
1123 Ok(typ.clone())
1124 }
1125 CallUnmaterializable(func) => Ok(func.output_type()),
1126 CallUnary { expr, func } => {
1127 Ok(func.output_type(tc.typecheck_scalar(expr, source, column_types)?))
1128 }
1129 CallBinary { expr1, expr2, func } => Ok(func.output_type(
1130 tc.typecheck_scalar(expr1, source, column_types)?,
1131 tc.typecheck_scalar(expr2, source, column_types)?,
1132 )),
1133 CallVariadic { exprs, func } => Ok(func.output_type(
1134 exprs
1135 .iter()
1136 .map(|e| tc.typecheck_scalar(e, source, column_types))
1137 .collect::<Result<Vec<_>, TypeError>>()?,
1138 )),
1139 If { cond, then, els } => {
1140 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1141
1142 if cond_type.scalar_type != ScalarType::Bool {
1146 let sub = cond_type.scalar_type.clone();
1147
1148 return Err(TypeError::MismatchColumn {
1149 source,
1150 got: cond_type,
1151 expected: ColumnType {
1152 scalar_type: ScalarType::Bool,
1153 nullable: true,
1154 },
1155 diffs: vec![ColumnTypeDifference::NotSubtype {
1156 sub,
1157 sup: ScalarType::Bool,
1158 }],
1159 message: "expected boolean condition".into(),
1160 });
1161 }
1162
1163 let then_type = tc.typecheck_scalar(then, source, column_types)?;
1164 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1165 then_type.union(&else_type).map_err(|e| {
1166 let diffs = column_subtype_difference(&then_type, &else_type);
1167
1168 TypeError::MismatchColumn {
1169 source,
1170 got: then_type,
1171 expected: else_type,
1172 diffs,
1173 message: format!("couldn't compute union of column types for if: {e}"),
1174 }
1175 })
1176 }
1177 })
1178 }
1179
1180 pub fn typecheck_aggregate<'a>(
1182 &self,
1183 expr: &'a AggregateExpr,
1184 source: &'a MirRelationExpr,
1185 column_types: &[ColumnType],
1186 ) -> Result<ColumnType, TypeError<'a>> {
1187 self.checked_recur(|tc| {
1188 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1189
1190 Ok(expr.func.output_type(t_in))
1193 })
1194 }
1195}
1196
/// Logs a type error through `tracing`.
///
/// The first argument selects the severity; the remaining arguments are a format string and
/// its arguments. When the severity is `true`, the message is logged as a warning and is
/// followed by a fixed error-level line pointing at the tracking issue. When it is `false`,
/// the message is logged at debug level only.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        if $severity {
            ::tracing::warn!($($arg)+);
            ::tracing::error!("type error in MIR optimization (details in warning; see 'Type error omnibus' issue database-issues#5663 <https://github.com/MaterializeInc/database-issues/issues/5663>)");
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1210
1211impl crate::Transform for Typecheck {
1212 fn name(&self) -> &'static str {
1213 "Typecheck"
1214 }
1215
1216 fn actually_perform_transform(
1217 &self,
1218 relation: &mut MirRelationExpr,
1219 transform_ctx: &mut crate::TransformCtx,
1220 ) -> Result<(), crate::TransformError> {
1221 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1222
1223 let expected = transform_ctx
1224 .global_id
1225 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1226
1227 if let Some(id) = transform_ctx.global_id {
1228 if self.disallow_new_globals
1229 && expected.is_none()
1230 && transform_ctx.global_id.is_some()
1231 && !id.is_transient()
1232 {
1233 type_error!(
1234 false, "TYPE WARNING: NEW NON-TRANSIENT GLOBAL ID {id}\n{}",
1236 relation.pretty()
1237 );
1238 }
1239 }
1240
1241 let got = self.typecheck(relation, &typecheck_ctx);
1242
1243 let humanizer = mz_repr::explain::DummyHumanizer;
1244
1245 match (got, expected) {
1246 (Ok(got), Some(expected)) => {
1247 let id = transform_ctx.global_id.unwrap();
1248
1249 let diffs = relation_subtype_difference(expected, &got);
1251 if !diffs.is_empty() {
1252 let severity = diffs
1254 .iter()
1255 .any(|diff| diff.clone().ignore_nullability().is_some());
1256
1257 let err = TypeError::MismatchColumns {
1258 source: relation,
1259 got,
1260 expected: expected.clone(),
1261 diffs,
1262 message: format!(
1263 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1264 ),
1265 };
1266
1267 type_error!(severity, "TYPE ERROR IN KNOWN GLOBAL ID {id}:\n{err}");
1268 }
1269 }
1270 (Ok(got), None) => {
1271 if let Some(id) = transform_ctx.global_id {
1272 typecheck_ctx.insert(Id::Global(id), got);
1273 }
1274 }
1275 (Err(err), _) => {
1276 let (expected, binding) = match expected {
1277 Some(expected) => {
1278 let id = transform_ctx.global_id.unwrap();
1279 (
1280 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1281 format!("KNOWN GLOBAL ID {id}"),
1282 )
1283 }
1284 None => ("".to_string(), "TRANSIENT QUERY".to_string()),
1285 };
1286
1287 type_error!(
1288 true, "TYPE ERROR IN {binding}:\n{err}\n{expected}{}",
1290 relation.pretty()
1291 );
1292 }
1293 }
1294
1295 Ok(())
1296 }
1297}
1298
1299pub fn columns_pretty<H>(cols: &[ColumnType], humanizer: &H) -> String
1301where
1302 H: ExprHumanizer,
1303{
1304 let mut s = String::with_capacity(2 + 3 * cols.len());
1305
1306 s.push('(');
1307
1308 let mut it = cols.iter().peekable();
1309 while let Some(col) = it.next() {
1310 s.push_str(&humanizer.humanize_column_type(col, false));
1311
1312 if it.peek().is_some() {
1313 s.push_str(", ");
1314 }
1315 }
1316
1317 s.push(')');
1318
1319 s
1320}
1321
1322impl RelationTypeDifference {
1323 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1327 where
1328 H: ExprHumanizer,
1329 {
1330 use RelationTypeDifference::*;
1331 match self {
1332 Length { len_sub, len_sup } => {
1333 writeln!(
1334 f,
1335 " number of columns do not match ({len_sub} != {len_sup})"
1336 )
1337 }
1338 Column { col, diff } => {
1339 writeln!(f, " column {col} differs:")?;
1340 diff.humanize(4, h, f)
1341 }
1342 }
1343 }
1344}
1345
1346impl ColumnTypeDifference {
1347 pub fn humanize<H>(
1349 &self,
1350 indent: usize,
1351 h: &H,
1352 f: &mut std::fmt::Formatter<'_>,
1353 ) -> std::fmt::Result
1354 where
1355 H: ExprHumanizer,
1356 {
1357 use ColumnTypeDifference::*;
1358
1359 write!(f, "{:indent$}", "")?;
1361
1362 match self {
1363 NotSubtype { sub, sup } => {
1364 let sub = h.humanize_scalar_type(sub, false);
1365 let sup = h.humanize_scalar_type(sup, false);
1366
1367 writeln!(f, "{sub} is a not a subtype of {sup}")
1368 }
1369 Nullability { sub, sup } => {
1370 let sub = h.humanize_column_type(sub, false);
1371 let sup = h.humanize_column_type(sup, false);
1372
1373 writeln!(f, "{sub} is nullable but {sup} is not")
1374 }
1375 ElementType { ctor, element_type } => {
1376 writeln!(f, "{ctor} element types differ:")?;
1377
1378 element_type.humanize(indent + 2, h, f)
1379 }
1380 RecordMissingFields { missing } => {
1381 write!(f, "missing column fields:")?;
1382 for col in missing {
1383 write!(f, " {col}")?;
1384 }
1385 f.write_char('\n')
1386 }
1387 RecordFields { fields } => {
1388 writeln!(f, "{} record fields differ:", fields.len())?;
1389
1390 for (col, diff) in fields {
1391 writeln!(f, "{:indent$} field '{col}':", "")?;
1392 diff.humanize(indent + 4, h, f)?;
1393 }
1394 Ok(())
1395 }
1396 }
1397 }
1398}
1399
/// Pairs a [`TypeError`] with an `ExprHumanizer` so that the error can be displayed with
/// types printed by that humanizer.
#[allow(missing_debug_implementations)]
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to display.
    err: &'a TypeError<'a>,
    /// The humanizer used to print types.
    humanizer: &'b H,
}
1409
1410impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1411where
1412 H: ExprHumanizer,
1413{
1414 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1416 Self { err, humanizer }
1417 }
1418}
1419
1420impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1421where
1422 H: ExprHumanizer,
1423{
1424 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1425 self.err.humanize(self.humanizer, f)
1426 }
1427}
1428
1429impl<'a> std::fmt::Display for TypeError<'a> {
1430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1431 TypeErrorHumanizer {
1432 err: self,
1433 humanizer: &DummyHumanizer,
1434 }
1435 .fmt(f)
1436 }
1437}
1438
1439impl<'a> TypeError<'a> {
1440 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1442 use TypeError::*;
1443 match self {
1444 Unbound { source, .. }
1445 | NoSuchColumn { source, .. }
1446 | MismatchColumn { source, .. }
1447 | MismatchColumns { source, .. }
1448 | BadConstantRow { source, .. }
1449 | BadProject { source, .. }
1450 | BadJoinEquivalence { source, .. }
1451 | BadTopKGroupKey { source, .. }
1452 | BadTopKOrdering { source, .. }
1453 | BadLetRecBindings { source }
1454 | Shadowing { source, .. }
1455 | DisallowedDummy { source, .. } => Some(source),
1456 Recursion { .. } => None,
1457 }
1458 }
1459
1460 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1461 where
1462 H: ExprHumanizer,
1463 {
1464 if let Some(source) = self.source() {
1465 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1466 }
1467
1468 use TypeError::*;
1469 match self {
1470 Unbound { source: _, id, typ } => {
1471 let typ = columns_pretty(&typ.column_types, humanizer);
1472 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1473 }
1474 NoSuchColumn {
1475 source: _,
1476 expr,
1477 col,
1478 } => writeln!(f, "{expr} references non-existent column {col}")?,
1479 MismatchColumn {
1480 source: _,
1481 got,
1482 expected,
1483 diffs,
1484 message,
1485 } => {
1486 let got = humanizer.humanize_column_type(got, false);
1487 let expected = humanizer.humanize_column_type(expected, false);
1488 writeln!(
1489 f,
1490 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1491 )?;
1492
1493 for diff in diffs {
1494 diff.humanize(2, humanizer, f)?;
1495 }
1496 }
1497 MismatchColumns {
1498 source: _,
1499 got,
1500 expected,
1501 diffs,
1502 message,
1503 } => {
1504 let got = columns_pretty(got, humanizer);
1505 let expected = columns_pretty(expected, humanizer);
1506
1507 writeln!(
1508 f,
1509 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1510 )?;
1511
1512 for diff in diffs {
1513 diff.humanize(humanizer, f)?;
1514 }
1515 }
1516 BadConstantRow {
1517 source: _,
1518 got,
1519 expected,
1520 } => {
1521 let expected = columns_pretty(expected, humanizer);
1522
1523 writeln!(
1524 f,
1525 "bad constant row\n got {got}\nexpected row of type {expected}"
1526 )?
1527 }
1528 BadProject {
1529 source: _,
1530 got,
1531 input_type,
1532 } => {
1533 let input_type = columns_pretty(input_type, humanizer);
1534
1535 writeln!(
1536 f,
1537 "projection of non-existant columns {got:?} from type {input_type}"
1538 )?
1539 }
1540 BadJoinEquivalence {
1541 source: _,
1542 got,
1543 message,
1544 } => {
1545 let got = columns_pretty(got, humanizer);
1546
1547 writeln!(f, "bad join equivalence {got}: {message}")?
1548 }
1549 BadTopKGroupKey {
1550 source: _,
1551 k,
1552 input_type,
1553 } => {
1554 let input_type = columns_pretty(input_type, humanizer);
1555
1556 writeln!(
1557 f,
1558 "TopK group key component references invalid column {k} in columns: {input_type}"
1559 )?
1560 }
1561 BadTopKOrdering {
1562 source: _,
1563 order,
1564 input_type,
1565 } => {
1566 let col = order.column;
1567 let num_cols = input_type.len();
1568 let are = if num_cols == 1 { "is" } else { "are" };
1569 let s = if num_cols == 1 { "" } else { "s" };
1570 let input_type = columns_pretty(input_type, humanizer);
1571
1572 let mode = HumanizedExplain::new(false);
1574 let order = mode.expr(order, None);
1575
1576 writeln!(
1577 f,
1578 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
1579 )?
1580 }
1581 BadLetRecBindings { source: _ } => {
1582 writeln!(f, "LetRec ids and definitions don't line up")?
1583 }
1584 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
1585 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
1586 Recursion { error } => writeln!(f, "{error}")?,
1587 }
1588
1589 Ok(())
1590 }
1591}