1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
25use mz_repr::{
26 ColumnName, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType, Row,
27 SqlColumnType,
28};
29
/// A typechecking `Context` that can be shared between owners: it is
/// reference-counted with `Arc` and guarded by a `Mutex`.
pub type SharedContext = Arc<Mutex<Context>>;
35
36pub fn empty_context() -> SharedContext {
38 Arc::new(Mutex::new(BTreeMap::new()))
39}
40
/// The ways in which typechecking a `MirRelationExpr` can fail.
///
/// Every variant except `Recursion` carries `source`, the relation expression
/// in which the problem was found.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referred to an id that has no binding in the context.
    Unbound {
        /// The expression containing the unbound `Get`.
        source: &'a MirRelationExpr,
        /// The id that was looked up.
        id: Id,
        /// The type annotated on the `Get`.
        typ: ReprRelationType,
    },
    /// A scalar expression referenced a column index outside the available columns.
    NoSuchColumn {
        /// The relation expression containing the scalar expression.
        source: &'a MirRelationExpr,
        /// The offending column reference.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column's type did not match what was expected.
    MismatchColumn {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: ReprColumnType,
        /// The column type that was expected.
        expected: ReprColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprColumnTypeDifference>,
        /// A description of the context in which the mismatch occurred.
        message: String,
    },
    /// The column types of a whole relation did not match what was expected.
    MismatchColumns {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ReprColumnType>,
        /// The column types that were expected.
        expected: Vec<ReprColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprRelationTypeDifference>,
        /// A description of the context in which the mismatch occurred.
        message: String,
    },
    /// A constant row had the wrong number of datums, or a datum was not an
    /// instance of its column's type.
    BadConstantRow {
        /// The expression containing the constant row.
        source: &'a MirRelationExpr,
        /// The offending row.
        got: Row,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A projection referred to a column index outside its input.
    BadProject {
        /// The `Project` expression.
        source: &'a MirRelationExpr,
        /// The projected column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class was rejected.
    BadJoinEquivalence {
        /// The `Join` expression.
        source: &'a MirRelationExpr,
        /// The types of the expressions in the equivalence class.
        got: Vec<ReprColumnType>,
        /// A description of why the equivalence class was rejected.
        message: String,
    },
    /// A `TopK` group key referred to a column index outside its input.
    BadTopKGroupKey {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The out-of-range group key column.
        k: usize,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering referred to a column index outside its input.
    BadTopKOrdering {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The ordering whose column is out of range.
        order: ColumnOrder,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` had a different number of ids than values.
    BadLetRecBindings {
        /// The `LetRec` expression.
        source: &'a MirRelationExpr,
    },
    /// A `Let` or `LetRec` bound an id that was already bound.
    Shadowing {
        /// The binding expression.
        source: &'a MirRelationExpr,
        /// The id that was rebound.
        id: Id,
    },
    /// The recursion limit was exceeded while typechecking.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A dummy value was found while dummies were disallowed.
    DisallowedDummy {
        /// The expression containing the dummy value.
        source: &'a MirRelationExpr,
    },
}
160
161impl<'a> From<RecursionLimitError> for TypeError<'a> {
162 fn from(error: RecursionLimitError) -> Self {
163 TypeError::Recursion { error }
164 }
165}
166
/// The typechecking environment: maps each bound `Id` to the column types of
/// the relation it is bound to.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
168
/// One way in which a relation type `sub` fails to be a subtype of a relation
/// type `sup`.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relation types have different numbers of columns.
    Length {
        /// The number of columns in `sub`.
        len_sub: usize,
        /// The number of columns in `sup`.
        len_sup: usize,
    },
    /// The column at position `col` differs.
    Column {
        /// The index of the differing column.
        col: usize,
        /// How the column types differ.
        diff: ReprColumnTypeDifference,
    },
}
189
/// One way in which a column type `sub` fails to be a subtype of a column
/// type `sup`.
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar type `sub` is not a subtype of the scalar type `sup`.
    NotSubtype {
        /// The would-be subtype.
        sub: ReprScalarType,
        /// The would-be supertype.
        sup: ReprScalarType,
    },
    /// `sub` is nullable but `sup` is not.
    Nullability {
        /// The nullable column type.
        sub: ReprColumnType,
        /// The non-nullable column type.
        sup: ReprColumnType,
    },
    /// The element types of a container type differ.
    ElementType {
        /// The `Debug` rendering of the container's base type.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// Named record fields are missing.
    // NOTE(review): no code visible in this file constructs this variant.
    RecordMissingFields {
        /// The names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Some fields of a record type differ.
    RecordFields {
        /// The per-field differences.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
228
229impl ReprRelationTypeDifference {
230 pub fn ignore_nullability(self) -> Option<Self> {
234 use ReprRelationTypeDifference::*;
235
236 match self {
237 Length { .. } => Some(self),
238 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
239 }
240 }
241}
242
243impl ReprColumnTypeDifference {
244 pub fn ignore_nullability(self) -> Option<Self> {
248 use ReprColumnTypeDifference::*;
249
250 match self {
251 Nullability { .. } => None,
252 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
253 ElementType { ctor, element_type } => {
254 element_type
255 .ignore_nullability()
256 .map(|element_type| ElementType {
257 ctor,
258 element_type: Box::new(element_type),
259 })
260 }
261 RecordFields { fields } => {
262 let fields = fields
263 .into_iter()
264 .flat_map(|diff| diff.ignore_nullability())
265 .collect::<Vec<_>>();
266
267 if fields.is_empty() {
268 None
269 } else {
270 Some(RecordFields { fields })
271 }
272 }
273 }
274 }
275}
276
277pub fn relation_subtype_difference(
281 sub: &[ReprColumnType],
282 sup: &[ReprColumnType],
283) -> Vec<ReprRelationTypeDifference> {
284 let mut diffs = Vec::new();
285
286 if sub.len() != sup.len() {
287 diffs.push(ReprRelationTypeDifference::Length {
288 len_sub: sub.len(),
289 len_sup: sup.len(),
290 });
291
292 return diffs;
294 }
295
296 diffs.extend(
297 sub.iter()
298 .zip_eq(sup.iter())
299 .enumerate()
300 .flat_map(|(col, (sub_ty, sup_ty))| {
301 column_subtype_difference(sub_ty, sup_ty)
302 .into_iter()
303 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
304 }),
305 );
306
307 diffs
308}
309
310pub fn column_subtype_difference(
314 sub: &ReprColumnType,
315 sup: &ReprColumnType,
316) -> Vec<ReprColumnTypeDifference> {
317 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
318
319 if sub.nullable && !sup.nullable {
320 diffs.push(ReprColumnTypeDifference::Nullability {
321 sub: sub.clone(),
322 sup: sup.clone(),
323 });
324 }
325
326 diffs
327}
328
329pub fn scalar_subtype_difference(
333 sub: &ReprScalarType,
334 sup: &ReprScalarType,
335) -> Vec<ReprColumnTypeDifference> {
336 use ReprScalarType::*;
337
338 let mut diffs = Vec::new();
339
340 match (sub, sup) {
341 (
342 List {
343 element_type: sub_elt,
344 ..
345 },
346 List {
347 element_type: sup_elt,
348 ..
349 },
350 )
351 | (
352 Map {
353 value_type: sub_elt,
354 ..
355 },
356 Map {
357 value_type: sup_elt,
358 ..
359 },
360 )
361 | (
362 Range {
363 element_type: sub_elt,
364 ..
365 },
366 Range {
367 element_type: sup_elt,
368 ..
369 },
370 )
371 | (Array(sub_elt), Array(sup_elt)) => {
372 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
373 diffs.extend(
374 scalar_subtype_difference(sub_elt, sup_elt)
375 .into_iter()
376 .map(|diff| ReprColumnTypeDifference::ElementType {
377 ctor: ctor.clone(),
378 element_type: Box::new(diff),
379 }),
380 );
381 }
382 (
383 Record {
384 fields: sub_fields, ..
385 },
386 Record {
387 fields: sup_fields, ..
388 },
389 ) => {
390 if sub_fields.len() != sup_fields.len() {
391 diffs.push(ReprColumnTypeDifference::NotSubtype {
392 sub: sub.clone(),
393 sup: sup.clone(),
394 });
395 return diffs;
396 }
397
398 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
399 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
400 }
401 }
402 (_, _) => {
403 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
404 diffs.push(ReprColumnTypeDifference::NotSubtype {
405 sub: sub.clone(),
406 sup: sup.clone(),
407 })
408 }
409 }
410 };
411
412 diffs
413}
414
415pub fn scalar_union(
419 typ: &mut ReprScalarType,
420 other: &ReprScalarType,
421) -> Vec<ReprColumnTypeDifference> {
422 use ReprScalarType::*;
423
424 let mut diffs = Vec::new();
425
426 let ctor = ReprScalarBaseType::from(&*typ);
428 match (typ, other) {
429 (
430 List {
431 element_type: typ_elt,
432 },
433 List {
434 element_type: other_elt,
435 },
436 )
437 | (
438 Map {
439 value_type: typ_elt,
440 },
441 Map {
442 value_type: other_elt,
443 },
444 )
445 | (
446 Range {
447 element_type: typ_elt,
448 },
449 Range {
450 element_type: other_elt,
451 },
452 )
453 | (Array(typ_elt), Array(other_elt)) => {
454 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
455 diffs.extend(
456 res.into_iter()
457 .map(|diff| ReprColumnTypeDifference::ElementType {
458 ctor: format!("{ctor:?}"),
459 element_type: Box::new(diff),
460 }),
461 );
462 }
463 (
464 Record { fields: typ_fields },
465 Record {
466 fields: other_fields,
467 },
468 ) => {
469 if typ_fields.len() != other_fields.len() {
470 diffs.push(ReprColumnTypeDifference::NotSubtype {
471 sub: ReprScalarType::Record {
472 fields: typ_fields.clone(),
473 },
474 sup: other.clone(),
475 });
476 return diffs;
477 }
478
479 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
480 diffs.extend(column_union(typ_ty, other_ty));
481 }
482 }
483 (typ, _) => {
484 if ctor != ReprScalarBaseType::from(other) {
485 diffs.push(ReprColumnTypeDifference::NotSubtype {
486 sub: typ.clone(),
487 sup: other.clone(),
488 })
489 }
490 }
491 };
492
493 diffs
494}
495
496pub fn column_union(
500 typ: &mut ReprColumnType,
501 other: &ReprColumnType,
502) -> Vec<ReprColumnTypeDifference> {
503 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
504
505 if diffs.is_empty() {
506 typ.nullable |= other.nullable;
507 }
508
509 diffs
510}
511
512pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
517 if sub.len() != sup.len() {
518 return false;
519 }
520
521 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
522 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
523 })
524}
525
/// A typechecker for `MirRelationExpr`s, configured through builder-style
/// methods.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared context consulted and updated when running as a transform.
    ctx: SharedContext,
    /// When set, running as a transform logs a type warning for a
    /// non-transient global id that has no entry in the context yet.
    disallow_new_globals: bool,
    /// When set, join equivalence classes are additionally checked for
    /// nullability problems.
    strict_join_equivalences: bool,
    /// When set, dummy values in constants and scalar expressions are errors.
    disallow_dummy: bool,
    /// Limits the recursion depth of the typechecking functions.
    recursion_guard: RecursionGuard,
}
540
impl CheckedRecursion for Typecheck {
    /// Exposes the guard that bounds recursion in the typechecking methods.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
546
547impl Typecheck {
548 pub fn new(ctx: SharedContext) -> Self {
550 Self {
551 ctx,
552 disallow_new_globals: false,
553 strict_join_equivalences: false,
554 disallow_dummy: false,
555 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
556 }
557 }
558
559 pub fn disallow_new_globals(mut self) -> Self {
563 self.disallow_new_globals = true;
564 self
565 }
566
567 pub fn strict_join_equivalences(mut self) -> Self {
571 self.strict_join_equivalences = true;
572
573 self
574 }
575
576 pub fn disallow_dummy(mut self) -> Self {
578 self.disallow_dummy = true;
579 self
580 }
581
582 pub fn typecheck<'a>(
593 &self,
594 expr: &'a MirRelationExpr,
595 ctx: &Context,
596 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
597 use MirRelationExpr::*;
598
599 self.checked_recur(|tc| match expr {
600 Constant { typ, rows } => {
601 if let Ok(rows) = rows {
602 for (row, _id) in rows {
603 let datums = row.unpack();
604
605 if datums.len() != typ.column_types.len() {
607 return Err(TypeError::BadConstantRow {
608 source: expr,
609 got: row.clone(),
610 expected: typ.column_types.iter().map(ReprColumnType::from).collect(),
611 });
612 }
613
614 if datums
616 .iter()
617 .zip_eq(typ.column_types.iter())
618 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of_sql(ty))
619 {
620 return Err(TypeError::BadConstantRow {
621 source: expr,
622 got: row.clone(),
623 expected: typ.column_types.iter().map(ReprColumnType::from).collect(),
624 });
625 }
626
627 if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
628 return Err(TypeError::DisallowedDummy {
629 source: expr,
630 });
631 }
632 }
633 }
634
635 Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
636 }
637 Get { typ, id, .. } => {
638 if let Id::Global(_global_id) = id {
639 if !ctx.contains_key(id) {
640 return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
642 }
643 }
644
645 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
646 source: expr,
647 id: id.clone(),
648 typ: ReprRelationType::from(typ),
649 })?;
650
651 let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
652
653 let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
655
656 if !diffs.is_empty() {
657 return Err(TypeError::MismatchColumns {
658 source: expr,
659 got: column_types,
660 expected: ctx_typ.clone(),
661 diffs,
662 message: "annotation did not match context type".to_string(),
663 });
664 }
665
666 Ok(column_types)
667 }
668 Project { input, outputs } => {
669 let t_in = tc.typecheck(input, ctx)?;
670
671 for x in outputs {
672 if *x >= t_in.len() {
673 return Err(TypeError::BadProject {
674 source: expr,
675 got: outputs.clone(),
676 input_type: t_in,
677 });
678 }
679 }
680
681 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
682 }
683 Map { input, scalars } => {
684 let mut t_in = tc.typecheck(input, ctx)?;
685
686 for scalar_expr in scalars.iter() {
687 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
688
689 if self.disallow_dummy && scalar_expr.contains_dummy() {
690 return Err(TypeError::DisallowedDummy {
691 source: expr,
692 });
693 }
694 }
695
696 Ok(t_in)
697 }
698 FlatMap { input, func, exprs } => {
699 let mut t_in = tc.typecheck(input, ctx)?;
700
701 let mut t_exprs = Vec::with_capacity(exprs.len());
702 for scalar_expr in exprs {
703 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
704
705 if self.disallow_dummy && scalar_expr.contains_dummy() {
706 return Err(TypeError::DisallowedDummy {
707 source: expr,
708 });
709 }
710 }
711 let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
714
715 t_in.extend(t_out);
717 Ok(t_in)
718 }
719 Filter { input, predicates } => {
720 let mut t_in = tc.typecheck(input, ctx)?;
721
722 for column in non_nullable_columns(predicates) {
725 t_in[column].nullable = false;
726 }
727
728 for scalar_expr in predicates {
729 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
730
731 if t.scalar_type != ReprScalarType::Bool {
735 let sub = t.scalar_type.clone();
736
737 return Err(TypeError::MismatchColumn {
738 source: expr,
739 got: t,
740 expected: ReprColumnType {
741 scalar_type: ReprScalarType::Bool,
742 nullable: true,
743 },
744 diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
745 message: "expected boolean condition".to_string(),
746 });
747 }
748
749 if self.disallow_dummy && scalar_expr.contains_dummy() {
750 return Err(TypeError::DisallowedDummy {
751 source: expr,
752 });
753 }
754 }
755
756 Ok(t_in)
757 }
758 Join {
759 inputs,
760 equivalences,
761 implementation,
762 } => {
763 let mut t_in_global = Vec::new();
764 let mut t_in_local = vec![Vec::new(); inputs.len()];
765
766 for (i, input) in inputs.iter().enumerate() {
767 let input_t = tc.typecheck(input, ctx)?;
768 t_in_global.extend(input_t.clone());
769 t_in_local[i] = input_t;
770 }
771
772 for eq_class in equivalences {
773 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
774
775 let mut all_nullable = true;
776
777 for scalar_expr in eq_class {
778 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
780
781 if !t_expr.nullable {
782 all_nullable = false;
783 }
784
785 if let Some(t_first) = t_exprs.get(0) {
786 let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
787 if !diffs.is_empty() {
788 return Err(TypeError::MismatchColumn {
789 source: expr,
790 got: t_expr,
791 expected: t_first.clone(),
792 diffs,
793 message: "equivalence class members have different scalar types".to_string(),
794 });
795 }
796
797 if self.strict_join_equivalences {
801 if t_expr.nullable != t_first.nullable {
802 let sub = t_expr.clone();
803 let sup = t_first.clone();
804
805 let err = TypeError::MismatchColumn {
806 source: expr,
807 got: t_expr.clone(),
808 expected: t_first.clone(),
809 diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
810 message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
811 };
812
813 ::tracing::debug!("{err}");
815 }
816 }
817 }
818
819 if self.disallow_dummy && scalar_expr.contains_dummy() {
820 return Err(TypeError::DisallowedDummy {
821 source: expr,
822 });
823 }
824
825 t_exprs.push(t_expr);
826 }
827
828 if self.strict_join_equivalences && all_nullable {
829 let err = TypeError::BadJoinEquivalence {
830 source: expr,
831 got: t_exprs,
832 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
833 };
834
835 ::tracing::debug!("{err}");
837 }
838 }
839
840 match implementation {
842 JoinImplementation::Differential((start_idx, first_key, _), others) => {
843 if let Some(key) = first_key {
844 for k in key {
845 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
846 }
847 }
848
849 for (idx, key, _) in others {
850 for k in key {
851 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
852 }
853 }
854 }
855 JoinImplementation::DeltaQuery(plans) => {
856 for plan in plans {
857 for (idx, key, _) in plan {
858 for k in key {
859 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
860 }
861 }
862 }
863 }
864 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
865 let typ: Vec<ReprColumnType> = key
866 .iter()
867 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
868 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
869
870 for row in consts {
871 let datums = row.unpack();
872
873 if datums.len() != typ.len() {
875 return Err(TypeError::BadConstantRow {
876 source: expr,
877 got: row.clone(),
878 expected: typ,
879 });
880 }
881
882 if datums
884 .iter()
885 .zip_eq(typ.iter())
886 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of(ty))
887 {
888 return Err(TypeError::BadConstantRow {
889 source: expr,
890 got: row.clone(),
891 expected: typ,
892 });
893 }
894 }
895 }
896 JoinImplementation::Unimplemented => (),
897 }
898
899 Ok(t_in_global)
900 }
901 Reduce {
902 input,
903 group_key,
904 aggregates,
905 monotonic: _,
906 expected_group_size: _,
907 } => {
908 let t_in = tc.typecheck(input, ctx)?;
909
910 let mut t_out = group_key
911 .iter()
912 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
913 .collect::<Result<Vec<_>, _>>()?;
914
915 if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
916 return Err(TypeError::DisallowedDummy {
917 source: expr,
918 });
919 }
920
921 for agg in aggregates {
922 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
923 }
924
925 Ok(t_out)
926 }
927 TopK {
928 input,
929 group_key,
930 order_key,
931 limit: _,
932 offset: _,
933 monotonic: _,
934 expected_group_size: _,
935 } => {
936 let t_in = tc.typecheck(input, ctx)?;
937
938 for &k in group_key {
939 if k >= t_in.len() {
940 return Err(TypeError::BadTopKGroupKey {
941 source: expr,
942 k,
943 input_type: t_in,
944 });
945 }
946 }
947
948 for order in order_key {
949 if order.column >= t_in.len() {
950 return Err(TypeError::BadTopKOrdering {
951 source: expr,
952 order: order.clone(),
953 input_type: t_in,
954 });
955 }
956 }
957
958 Ok(t_in)
959 }
960 Negate { input } => tc.typecheck(input, ctx),
961 Threshold { input } => tc.typecheck(input, ctx),
962 Union { base, inputs } => {
963 let mut t_base = tc.typecheck(base, ctx)?;
964
965 for input in inputs {
966 let t_input = tc.typecheck(input, ctx)?;
967
968 let len_sub = t_base.len();
969 let len_sup = t_input.len();
970 if len_sub != len_sup {
971 return Err(TypeError::MismatchColumns {
972 source: expr,
973 got: t_base.clone(),
974 expected: t_input,
975 diffs: vec![ReprRelationTypeDifference::Length {
976 len_sub,
977 len_sup,
978 }],
979 message: "Union branches have different numbers of columns".to_string(),
980 });
981 }
982
983 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
984 let diffs = column_union(base_col, &input_col);
985 if !diffs.is_empty() {
986 return Err(TypeError::MismatchColumn {
987 source: expr,
988 got: input_col,
989 expected: base_col.clone(),
990 diffs,
991 message:
992 "couldn't compute union of column types in Union"
993 .to_string(),
994 });
995 }
996
997 }
998 }
999
1000 Ok(t_base)
1001 }
1002 Let { id, value, body } => {
1003 let t_value = tc.typecheck(value, ctx)?;
1004
1005 let binding = Id::Local(*id);
1006 if ctx.contains_key(&binding) {
1007 return Err(TypeError::Shadowing {
1008 source: expr,
1009 id: binding,
1010 });
1011 }
1012
1013 let mut body_ctx = ctx.clone();
1014 body_ctx.insert(Id::Local(*id), t_value);
1015
1016 tc.typecheck(body, &body_ctx)
1017 }
1018 LetRec { ids, values, body, limits: _ } => {
1019 if ids.len() != values.len() {
1020 return Err(TypeError::BadLetRecBindings { source: expr });
1021 }
1022
1023 let mut ctx = ctx.clone();
1026 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1028 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1029 }
1030
1031 for (id, value) in ids.iter().zip_eq(values.iter()) {
1032 let typ = tc.typecheck(value, &ctx)?;
1033
1034 let id = Id::Local(id.clone());
1035 if let Some(ctx_typ) = ctx.get_mut(&id) {
1036 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1037 let diffs = column_union(base_col, &input_col);
1039 if !diffs.is_empty() {
1040 return Err(TypeError::MismatchColumn {
1041 source: expr,
1042 got: input_col,
1043 expected: base_col.clone(),
1044 diffs,
1045 message:
1046 "couldn't compute union of column types in LetRec"
1047 .to_string(),
1048 })
1049 }
1050 }
1051 } else {
1052 ctx.insert(id, typ);
1054 }
1055 }
1056
1057 tc.typecheck(body, &ctx)
1058 }
1059 ArrangeBy { input, keys } => {
1060 let t_in = tc.typecheck(input, ctx)?;
1061
1062 for key in keys {
1063 for k in key {
1064 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1065 }
1066 }
1067
1068 Ok(t_in)
1069 }
1070 })
1071 }
1072
1073 fn collect_recursive_variable_types<'a>(
1077 &self,
1078 expr: &'a MirRelationExpr,
1079 ids: &[LocalId],
1080 ctx: &mut Context,
1081 ) -> Result<(), TypeError<'a>> {
1082 use MirRelationExpr::*;
1083
1084 self.checked_recur(|tc| {
1085 match expr {
1086 Get {
1087 id: Id::Local(id),
1088 typ,
1089 ..
1090 } => {
1091 if !ids.contains(id) {
1092 return Ok(());
1093 }
1094
1095 let id = Id::Local(id.clone());
1096 if let Some(ctx_typ) = ctx.get_mut(&id) {
1097 let typ = typ
1098 .column_types
1099 .iter()
1100 .map(ReprColumnType::from)
1101 .collect_vec();
1102
1103 if ctx_typ.len() != typ.len() {
1104 let diffs = relation_subtype_difference(&typ, ctx_typ);
1105
1106 return Err(TypeError::MismatchColumns {
1107 source: expr,
1108 got: typ,
1109 expected: ctx_typ.clone(),
1110 diffs,
1111 message: "environment and type annotation did not match"
1112 .to_string(),
1113 });
1114 }
1115
1116 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1117 let diffs = column_union(base_col, &input_col);
1118 if !diffs.is_empty() {
1119 return Err(TypeError::MismatchColumn {
1120 source: expr,
1121 got: input_col,
1122 expected: base_col.clone(),
1123 diffs,
1124 message:
1125 "couldn't compute union of column types in Get and context"
1126 .to_string(),
1127 });
1128 }
1129 }
1130 } else {
1131 ctx.insert(
1132 id,
1133 typ.column_types
1134 .iter()
1135 .map(ReprColumnType::from)
1136 .collect_vec(),
1137 );
1138 }
1139 }
1140 Get {
1141 id: Id::Global(..), ..
1142 }
1143 | Constant { .. } => (),
1144 Let { id, value, body } => {
1145 tc.collect_recursive_variable_types(value, ids, ctx)?;
1146
1147 if ids.contains(id) {
1149 return Err(TypeError::Shadowing {
1150 source: expr,
1151 id: Id::Local(*id),
1152 });
1153 }
1154
1155 tc.collect_recursive_variable_types(body, ids, ctx)?;
1156 }
1157 LetRec {
1158 ids: inner_ids,
1159 values,
1160 body,
1161 limits: _,
1162 } => {
1163 for inner_id in inner_ids {
1164 if ids.contains(inner_id) {
1165 return Err(TypeError::Shadowing {
1166 source: expr,
1167 id: Id::Local(*inner_id),
1168 });
1169 }
1170 }
1171
1172 for value in values {
1173 tc.collect_recursive_variable_types(value, ids, ctx)?;
1174 }
1175
1176 tc.collect_recursive_variable_types(body, ids, ctx)?;
1177 }
1178 Project { input, .. }
1179 | Map { input, .. }
1180 | FlatMap { input, .. }
1181 | Filter { input, .. }
1182 | Reduce { input, .. }
1183 | TopK { input, .. }
1184 | Negate { input }
1185 | Threshold { input }
1186 | ArrangeBy { input, .. } => {
1187 tc.collect_recursive_variable_types(input, ids, ctx)?;
1188 }
1189 Join { inputs, .. } => {
1190 for input in inputs {
1191 tc.collect_recursive_variable_types(input, ids, ctx)?;
1192 }
1193 }
1194 Union { base, inputs } => {
1195 tc.collect_recursive_variable_types(base, ids, ctx)?;
1196
1197 for input in inputs {
1198 tc.collect_recursive_variable_types(input, ids, ctx)?;
1199 }
1200 }
1201 }
1202
1203 Ok(())
1204 })
1205 }
1206
1207 fn typecheck_scalar<'a>(
1208 &self,
1209 expr: &'a MirScalarExpr,
1210 source: &'a MirRelationExpr,
1211 column_types: &[ReprColumnType],
1212 ) -> Result<ReprColumnType, TypeError<'a>> {
1213 use MirScalarExpr::*;
1214
1215 self.checked_recur(|tc| match expr {
1216 Column(i, _) => match column_types.get(*i) {
1217 Some(ty) => Ok(ty.clone()),
1218 None => Err(TypeError::NoSuchColumn {
1219 source,
1220 expr,
1221 col: *i,
1222 }),
1223 },
1224 Literal(row, typ) => {
1225 let typ = ReprColumnType::from(typ);
1226 if let Ok(row) = row {
1227 let datums = row.unpack();
1228
1229 if datums.len() != 1
1230 || (datums[0] != mz_repr::Datum::Dummy && !datums[0].is_instance_of(&typ))
1231 {
1232 return Err(TypeError::BadConstantRow {
1233 source,
1234 got: row.clone(),
1235 expected: vec![typ],
1236 });
1237 }
1238 }
1239
1240 Ok(typ)
1241 }
1242 CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1243 CallUnary { expr, func } => {
1244 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1245 let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1246 Ok(ReprColumnType::from(&typ_out))
1247 }
1248 CallBinary { expr1, expr2, func } => {
1249 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1250 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1251 let typ_out = func.output_type(
1252 SqlColumnType::from_repr(&typ_in1),
1253 SqlColumnType::from_repr(&typ_in2),
1254 );
1255 Ok(ReprColumnType::from(&typ_out))
1256 }
1257 CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1258 &func.output_type(
1259 exprs
1260 .iter()
1261 .map(|e| {
1262 tc.typecheck_scalar(e, source, column_types)
1263 .map(|typ| SqlColumnType::from_repr(&typ))
1264 })
1265 .collect::<Result<Vec<_>, TypeError>>()?,
1266 ),
1267 )),
1268 If { cond, then, els } => {
1269 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1270
1271 if cond_type.scalar_type != ReprScalarType::Bool {
1275 let sub = cond_type.scalar_type.clone();
1276
1277 return Err(TypeError::MismatchColumn {
1278 source,
1279 got: cond_type,
1280 expected: ReprColumnType {
1281 scalar_type: ReprScalarType::Bool,
1282 nullable: true,
1283 },
1284 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1285 sub,
1286 sup: ReprScalarType::Bool,
1287 }],
1288 message: "expected boolean condition".to_string(),
1289 });
1290 }
1291
1292 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1293 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1294
1295 let diffs = column_union(&mut then_type, &else_type);
1296 if !diffs.is_empty() {
1297 return Err(TypeError::MismatchColumn {
1298 source,
1299 got: then_type,
1300 expected: else_type,
1301 diffs,
1302 message: "couldn't compute union of column types for If".to_string(),
1303 });
1304 }
1305
1306 Ok(then_type)
1307 }
1308 })
1309 }
1310
1311 pub fn typecheck_aggregate<'a>(
1313 &self,
1314 expr: &'a AggregateExpr,
1315 source: &'a MirRelationExpr,
1316 column_types: &[ReprColumnType],
1317 ) -> Result<ReprColumnType, TypeError<'a>> {
1318 self.checked_recur(|tc| {
1319 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1320
1321 Ok(ReprColumnType::from(
1324 &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1325 ))
1326 })
1327 }
1328}
1329
/// Reports a type error with the given format arguments.
///
/// When the first argument is `true` the message goes through
/// `soft_panic_or_log!`; otherwise it is only logged at debug level.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        match $severity {
            true => {
                soft_panic_or_log!($($arg)+);
            }
            false => {
                ::tracing::debug!($($arg)+);
            }
        }
    }}
}
1342
1343impl crate::Transform for Typecheck {
1344 fn name(&self) -> &'static str {
1345 "Typecheck"
1346 }
1347
1348 fn actually_perform_transform(
1349 &self,
1350 relation: &mut MirRelationExpr,
1351 transform_ctx: &mut crate::TransformCtx,
1352 ) -> Result<(), crate::TransformError> {
1353 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1354
1355 let expected = transform_ctx
1356 .global_id
1357 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1358
1359 if let Some(id) = transform_ctx.global_id {
1360 if self.disallow_new_globals
1361 && expected.is_none()
1362 && transform_ctx.global_id.is_some()
1363 && !id.is_transient()
1364 {
1365 type_error!(
1366 false, "type warning: new non-transient global id {id}\n{}",
1368 relation.pretty()
1369 );
1370 }
1371 }
1372
1373 let got = self.typecheck(relation, &typecheck_ctx);
1374
1375 let humanizer = mz_repr::explain::DummyHumanizer;
1376
1377 match (got, expected) {
1378 (Ok(got), Some(expected)) => {
1379 let id = transform_ctx.global_id.unwrap();
1380
1381 let diffs = relation_subtype_difference(expected, &got);
1383 if !diffs.is_empty() {
1384 let severity = diffs
1386 .iter()
1387 .any(|diff| diff.clone().ignore_nullability().is_some());
1388
1389 let err = TypeError::MismatchColumns {
1390 source: relation,
1391 got,
1392 expected: expected.clone(),
1393 diffs,
1394 message: format!(
1395 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1396 ),
1397 };
1398
1399 type_error!(severity, "type error in known global id {id}:\n{err}");
1400 }
1401 }
1402 (Ok(got), None) => {
1403 if let Some(id) = transform_ctx.global_id {
1404 typecheck_ctx.insert(Id::Global(id), got);
1405 }
1406 }
1407 (Err(err), _) => {
1408 let (expected, binding) = match expected {
1409 Some(expected) => {
1410 let id = transform_ctx.global_id.unwrap();
1411 (
1412 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1413 format!("known global id {id}"),
1414 )
1415 }
1416 None => ("".to_string(), "transient query".to_string()),
1417 };
1418
1419 type_error!(
1420 true, "type error in {binding}:\n{err}\n{expected}{}",
1422 relation.pretty()
1423 );
1424 }
1425 }
1426
1427 Ok(())
1428 }
1429}
1430
1431pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1433where
1434 H: ExprHumanizer,
1435{
1436 let mut s = String::with_capacity(2 + 3 * cols.len());
1437
1438 s.push('(');
1439
1440 let mut it = cols.iter().peekable();
1441 while let Some(col) = it.next() {
1442 s.push_str(&humanizer.humanize_column_type_repr(col, false));
1443
1444 if it.peek().is_some() {
1445 s.push_str(", ");
1446 }
1447 }
1448
1449 s.push(')');
1450
1451 s
1452}
1453
1454impl ReprRelationTypeDifference {
1455 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1459 where
1460 H: ExprHumanizer,
1461 {
1462 use ReprRelationTypeDifference::*;
1463 match self {
1464 Length { len_sub, len_sup } => {
1465 writeln!(
1466 f,
1467 " number of columns do not match ({len_sub} != {len_sup})"
1468 )
1469 }
1470 Column { col, diff } => {
1471 writeln!(f, " column {col} differs:")?;
1472 diff.humanize(4, h, f)
1473 }
1474 }
1475 }
1476}
1477
1478impl ReprColumnTypeDifference {
1479 pub fn humanize<H>(
1481 &self,
1482 indent: usize,
1483 h: &H,
1484 f: &mut std::fmt::Formatter<'_>,
1485 ) -> std::fmt::Result
1486 where
1487 H: ExprHumanizer,
1488 {
1489 use ReprColumnTypeDifference::*;
1490
1491 write!(f, "{:indent$}", "")?;
1493
1494 match self {
1495 NotSubtype { sub, sup } => {
1496 let sub = h.humanize_scalar_type_repr(sub, false);
1497 let sup = h.humanize_scalar_type_repr(sup, false);
1498
1499 writeln!(f, "{sub} is a not a subtype of {sup}")
1500 }
1501 Nullability { sub, sup } => {
1502 let sub = h.humanize_column_type_repr(sub, false);
1503 let sup = h.humanize_column_type_repr(sup, false);
1504
1505 writeln!(f, "{sub} is nullable but {sup} is not")
1506 }
1507 ElementType { ctor, element_type } => {
1508 writeln!(f, "{ctor} element types differ:")?;
1509
1510 element_type.humanize(indent + 2, h, f)
1511 }
1512 RecordMissingFields { missing } => {
1513 write!(f, "missing column fields:")?;
1514 for col in missing {
1515 write!(f, " {col}")?;
1516 }
1517 f.write_char('\n')
1518 }
1519 RecordFields { fields } => {
1520 writeln!(f, "{} record fields differ:", fields.len())?;
1521
1522 for (i, diff) in fields.iter().enumerate() {
1523 writeln!(f, "{:indent$} field {i}:", "")?;
1524 diff.humanize(indent + 4, h, f)?;
1525 }
1526 Ok(())
1527 }
1528 }
1529 }
1530}
1531
1532#[allow(missing_debug_implementations)]
1534pub struct TypeErrorHumanizer<'a, 'b, H>
1535where
1536 H: ExprHumanizer,
1537{
1538 err: &'a TypeError<'a>,
1539 humanizer: &'b H,
1540}
1541
1542impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1543where
1544 H: ExprHumanizer,
1545{
1546 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1548 Self { err, humanizer }
1549 }
1550}
1551
1552impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1553where
1554 H: ExprHumanizer,
1555{
1556 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1557 self.err.humanize(self.humanizer, f)
1558 }
1559}
1560
1561impl<'a> std::fmt::Display for TypeError<'a> {
1562 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1563 TypeErrorHumanizer {
1564 err: self,
1565 humanizer: &DummyHumanizer,
1566 }
1567 .fmt(f)
1568 }
1569}
1570
1571impl<'a> TypeError<'a> {
1572 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1574 use TypeError::*;
1575 match self {
1576 Unbound { source, .. }
1577 | NoSuchColumn { source, .. }
1578 | MismatchColumn { source, .. }
1579 | MismatchColumns { source, .. }
1580 | BadConstantRow { source, .. }
1581 | BadProject { source, .. }
1582 | BadJoinEquivalence { source, .. }
1583 | BadTopKGroupKey { source, .. }
1584 | BadTopKOrdering { source, .. }
1585 | BadLetRecBindings { source }
1586 | Shadowing { source, .. }
1587 | DisallowedDummy { source, .. } => Some(source),
1588 Recursion { .. } => None,
1589 }
1590 }
1591
1592 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1593 where
1594 H: ExprHumanizer,
1595 {
1596 if let Some(source) = self.source() {
1597 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1598 }
1599
1600 use TypeError::*;
1601 match self {
1602 Unbound { source: _, id, typ } => {
1603 let typ = columns_pretty(&typ.column_types, humanizer);
1604 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1605 }
1606 NoSuchColumn {
1607 source: _,
1608 expr,
1609 col,
1610 } => writeln!(f, "{expr} references non-existent column {col}")?,
1611 MismatchColumn {
1612 source: _,
1613 got,
1614 expected,
1615 diffs,
1616 message,
1617 } => {
1618 let got = humanizer.humanize_column_type_repr(got, false);
1619 let expected = humanizer.humanize_column_type_repr(expected, false);
1620 writeln!(
1621 f,
1622 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1623 )?;
1624
1625 for diff in diffs {
1626 diff.humanize(2, humanizer, f)?;
1627 }
1628 }
1629 MismatchColumns {
1630 source: _,
1631 got,
1632 expected,
1633 diffs,
1634 message,
1635 } => {
1636 let got = columns_pretty(got, humanizer);
1637 let expected = columns_pretty(expected, humanizer);
1638
1639 writeln!(
1640 f,
1641 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1642 )?;
1643
1644 for diff in diffs {
1645 diff.humanize(humanizer, f)?;
1646 }
1647 }
1648 BadConstantRow {
1649 source: _,
1650 got,
1651 expected,
1652 } => {
1653 let expected = columns_pretty(expected, humanizer);
1654
1655 writeln!(
1656 f,
1657 "bad constant row\n got {got}\nexpected row of type {expected}"
1658 )?
1659 }
1660 BadProject {
1661 source: _,
1662 got,
1663 input_type,
1664 } => {
1665 let input_type = columns_pretty(input_type, humanizer);
1666
1667 writeln!(
1668 f,
1669 "projection of non-existant columns {got:?} from type {input_type}"
1670 )?
1671 }
1672 BadJoinEquivalence {
1673 source: _,
1674 got,
1675 message,
1676 } => {
1677 let got = columns_pretty(got, humanizer);
1678
1679 writeln!(f, "bad join equivalence {got}: {message}")?
1680 }
1681 BadTopKGroupKey {
1682 source: _,
1683 k,
1684 input_type,
1685 } => {
1686 let input_type = columns_pretty(input_type, humanizer);
1687
1688 writeln!(
1689 f,
1690 "TopK group key component references invalid column {k} in columns: {input_type}"
1691 )?
1692 }
1693 BadTopKOrdering {
1694 source: _,
1695 order,
1696 input_type,
1697 } => {
1698 let col = order.column;
1699 let num_cols = input_type.len();
1700 let are = if num_cols == 1 { "is" } else { "are" };
1701 let s = if num_cols == 1 { "" } else { "s" };
1702 let input_type = columns_pretty(input_type, humanizer);
1703
1704 let mode = HumanizedExplain::new(false);
1706 let order = mode.expr(order, None);
1707
1708 writeln!(
1709 f,
1710 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
1711 )?
1712 }
1713 BadLetRecBindings { source: _ } => {
1714 writeln!(f, "LetRec ids and definitions don't line up")?
1715 }
1716 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
1717 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
1718 Recursion { error } => writeln!(f, "{error}")?,
1719 }
1720
1721 Ok(())
1722 }
1723}