1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
25use mz_repr::{ColumnName, Row, SqlColumnType, SqlRelationType, SqlScalarBaseType, SqlScalarType};
26
/// A reference-counted, mutex-guarded typechecking [`Context`].
///
/// [`Typecheck`] stores one of these and locks it each time it runs as a transform.
pub type SharedContext = Arc<Mutex<Context>>;
32
33pub fn empty_context() -> SharedContext {
35 Arc::new(Mutex::new(BTreeMap::new()))
36}
37
/// The errors that typechecking can report.
///
/// Every variant except [`TypeError::Recursion`] carries `source`, the relation expression that
/// was being checked when the problem was found.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` refers to an id that has no entry in the typechecking context.
    Unbound {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The id that could not be found.
        id: Id,
        /// The type annotated on the `Get`.
        typ: SqlRelationType,
    },
    /// A scalar expression refers to a column index that the input does not have.
    NoSuchColumn {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The scalar expression containing the bad column reference.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column has a type other than the one that was required.
    MismatchColumn {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: SqlColumnType,
        /// The column type that was required.
        expected: SqlColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<SqlColumnTypeDifference>,
        /// A description of where the mismatch occurred.
        message: String,
    },
    /// A relation has column types other than the ones that were required.
    MismatchColumns {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<SqlColumnType>,
        /// The column types that were required.
        expected: Vec<SqlColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<SqlRelationTypeDifference>,
        /// A description of where the mismatch occurred.
        message: String,
    },
    /// A constant row has the wrong number of datums, or a datum that is not an instance of its
    /// column's type.
    BadConstantRow {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The offending row.
        got: Row,
        /// The column types the row was checked against.
        expected: Vec<SqlColumnType>,
    },
    /// A `Project` lists an output column index that its input does not have.
    BadProject {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The full list of projected column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<SqlColumnType>,
    },
    /// A join equivalence class is ill formed.
    BadJoinEquivalence {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The types of the expressions in the equivalence class.
        got: Vec<SqlColumnType>,
        /// A description of the problem.
        message: String,
    },
    /// A `TopK` group key names a column index that its input does not have.
    BadTopKGroupKey {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The out-of-range column index.
        k: usize,
        /// The column types of the `TopK`'s input.
        input_type: Vec<SqlColumnType>,
    },
    /// A `TopK` ordering names a column index that its input does not have.
    BadTopKOrdering {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The ordering with the out-of-range column.
        order: ColumnOrder,
        /// The column types of the `TopK`'s input.
        input_type: Vec<SqlColumnType>,
    },
    /// A `LetRec` has a different number of ids than of values.
    BadLetRecBindings {
        /// The expression being checked.
        source: &'a MirRelationExpr,
    },
    /// A local binding reuses an id that is already bound.
    Shadowing {
        /// The expression being checked.
        source: &'a MirRelationExpr,
        /// The id that was bound twice.
        id: Id,
    },
    /// Typechecking hit the recursion limit.
    Recursion {
        /// The underlying recursion limit error.
        error: RecursionLimitError,
    },
    /// A dummy datum was found while dummies are disallowed.
    DisallowedDummy {
        /// The expression being checked.
        source: &'a MirRelationExpr,
    },
}
157
158impl<'a> From<RecursionLimitError> for TypeError<'a> {
159 fn from(error: RecursionLimitError) -> Self {
160 TypeError::Recursion { error }
161 }
162}
163
/// The typechecking context: for each bound [`Id`], the column types recorded for it.
type Context = BTreeMap<Id, Vec<SqlColumnType>>;
165
/// One way in which a relation type `sub` fails to be a subtype of a relation type `sup`, as
/// computed by [`relation_subtype_difference`].
#[derive(Clone, Debug, Hash)]
pub enum SqlRelationTypeDifference {
    /// The two relation types have different numbers of columns.
    Length {
        /// Number of columns in `sub`.
        len_sub: usize,
        /// Number of columns in `sup`.
        len_sup: usize,
    },
    /// The column at a given position differs.
    Column {
        /// Zero-based position of the column.
        col: usize,
        /// How the two column types differ.
        diff: SqlColumnTypeDifference,
    },
}
186
/// One way in which a column type `sub` fails to be a subtype of a column type `sup`, as
/// computed by [`column_subtype_difference`] and [`scalar_subtype_difference`].
#[derive(Clone, Debug, Hash)]
pub enum SqlColumnTypeDifference {
    /// The scalar types have different base types.
    NotSubtype {
        /// The would-be subtype.
        sub: SqlScalarType,
        /// The would-be supertype.
        sup: SqlScalarType,
    },
    /// `sub` is nullable although `sup` is not.
    Nullability {
        /// The would-be subtype.
        sub: SqlColumnType,
        /// The would-be supertype.
        sup: SqlColumnType,
    },
    /// Both types are built from the same container constructor, but their element types differ.
    ElementType {
        /// Debug rendering of the container's base type.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<SqlColumnTypeDifference>,
    },
    /// Both types are records, but `sub` lacks some fields of `sup`.
    RecordMissingFields {
        /// Names of the fields that are absent from `sub`.
        missing: Vec<ColumnName>,
    },
    /// Both types are records, but some fields they have in common differ in type.
    RecordFields {
        /// Each differing field's name paired with one difference found for it.
        fields: Vec<(ColumnName, SqlColumnTypeDifference)>,
    },
}
225
226impl SqlRelationTypeDifference {
227 pub fn ignore_nullability(self) -> Option<Self> {
231 use SqlRelationTypeDifference::*;
232
233 match self {
234 Length { .. } => Some(self),
235 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
236 }
237 }
238}
239
240impl SqlColumnTypeDifference {
241 pub fn ignore_nullability(self) -> Option<Self> {
245 use SqlColumnTypeDifference::*;
246
247 match self {
248 Nullability { .. } => None,
249 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
250 ElementType { ctor, element_type } => {
251 element_type
252 .ignore_nullability()
253 .map(|element_type| ElementType {
254 ctor,
255 element_type: Box::new(element_type),
256 })
257 }
258 RecordFields { fields } => {
259 let fields = fields
260 .into_iter()
261 .flat_map(|(col, diff)| diff.ignore_nullability().map(|diff| (col, diff)))
262 .collect::<Vec<_>>();
263
264 if fields.is_empty() {
265 None
266 } else {
267 Some(RecordFields { fields })
268 }
269 }
270 }
271 }
272}
273
274pub fn relation_subtype_difference(
278 sub: &[SqlColumnType],
279 sup: &[SqlColumnType],
280) -> Vec<SqlRelationTypeDifference> {
281 let mut diffs = Vec::new();
282
283 if sub.len() != sup.len() {
284 diffs.push(SqlRelationTypeDifference::Length {
285 len_sub: sub.len(),
286 len_sup: sup.len(),
287 });
288
289 return diffs;
291 }
292
293 diffs.extend(
294 sub.iter()
295 .zip_eq(sup.iter())
296 .enumerate()
297 .flat_map(|(col, (sub_ty, sup_ty))| {
298 column_subtype_difference(sub_ty, sup_ty)
299 .into_iter()
300 .map(move |diff| SqlRelationTypeDifference::Column { col, diff })
301 }),
302 );
303
304 diffs
305}
306
307pub fn column_subtype_difference(
311 sub: &SqlColumnType,
312 sup: &SqlColumnType,
313) -> Vec<SqlColumnTypeDifference> {
314 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
315
316 if sub.nullable && !sup.nullable {
317 diffs.push(SqlColumnTypeDifference::Nullability {
318 sub: sub.clone(),
319 sup: sup.clone(),
320 });
321 }
322
323 diffs
324}
325
326pub fn scalar_subtype_difference(
330 sub: &SqlScalarType,
331 sup: &SqlScalarType,
332) -> Vec<SqlColumnTypeDifference> {
333 use SqlScalarType::*;
334
335 let mut diffs = Vec::new();
336
337 match (sub, sup) {
338 (
339 List {
340 element_type: sub_elt,
341 ..
342 },
343 List {
344 element_type: sup_elt,
345 ..
346 },
347 )
348 | (
349 Map {
350 value_type: sub_elt,
351 ..
352 },
353 Map {
354 value_type: sup_elt,
355 ..
356 },
357 )
358 | (
359 Range {
360 element_type: sub_elt,
361 ..
362 },
363 Range {
364 element_type: sup_elt,
365 ..
366 },
367 )
368 | (Array(sub_elt), Array(sup_elt)) => {
369 let ctor = format!("{:?}", SqlScalarBaseType::from(sub));
370 diffs.extend(
371 scalar_subtype_difference(sub_elt, sup_elt)
372 .into_iter()
373 .map(|diff| SqlColumnTypeDifference::ElementType {
374 ctor: ctor.clone(),
375 element_type: Box::new(diff),
376 }),
377 );
378 }
379 (
380 Record {
381 fields: sub_fields, ..
382 },
383 Record {
384 fields: sup_fields, ..
385 },
386 ) => {
387 let sub = sub_fields
388 .iter()
389 .map(|(sub_field, sub_ty)| (sub_field.clone(), sub_ty))
390 .collect::<BTreeMap<_, _>>();
391
392 let mut missing = Vec::new();
393 let mut field_diffs = Vec::new();
394 for (sup_field, sup_ty) in sup_fields {
395 if let Some(sub_ty) = sub.get(sup_field) {
396 let diff = column_subtype_difference(sub_ty, sup_ty);
397
398 if !diff.is_empty() {
399 field_diffs.push((sup_field.clone(), diff));
400 }
401 } else {
402 missing.push(sup_field.clone());
403 }
404 }
405 }
406 (_, _) => {
407 if SqlScalarBaseType::from(sub) != SqlScalarBaseType::from(sup) {
409 diffs.push(SqlColumnTypeDifference::NotSubtype {
410 sub: sub.clone(),
411 sup: sup.clone(),
412 })
413 }
414 }
415 };
416
417 diffs
418}
419
420pub fn is_subtype_of(sub: &[SqlColumnType], sup: &[SqlColumnType]) -> bool {
425 if sub.len() != sup.len() {
426 return false;
427 }
428
429 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
430 (!known.nullable || got.nullable) && got.scalar_type.base_eq(&known.scalar_type)
431 })
432}
433
/// A typechecker for MIR relation expressions, usable as a `Transform`.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared context of known types; locked while the transform runs.
    ctx: SharedContext,
    /// When set, the transform logs a type warning for a non-transient global id that is not yet
    /// in `ctx`.
    disallow_new_globals: bool,
    /// When set, join equivalence classes are additionally examined for nullability
    /// mismatches (currently only logged at debug level).
    strict_join_equivalences: bool,
    /// When set, any dummy datum in the checked expression is a [`TypeError::DisallowedDummy`].
    disallow_dummy: bool,
    /// Bounds the recursion depth of typechecking.
    recursion_guard: RecursionGuard,
}
448
449impl CheckedRecursion for Typecheck {
450 fn recursion_guard(&self) -> &RecursionGuard {
451 &self.recursion_guard
452 }
453}
454
455impl Typecheck {
456 pub fn new(ctx: SharedContext) -> Self {
458 Self {
459 ctx,
460 disallow_new_globals: false,
461 strict_join_equivalences: false,
462 disallow_dummy: false,
463 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
464 }
465 }
466
467 pub fn disallow_new_globals(mut self) -> Self {
471 self.disallow_new_globals = true;
472 self
473 }
474
475 pub fn strict_join_equivalences(mut self) -> Self {
479 self.strict_join_equivalences = true;
480
481 self
482 }
483
484 pub fn disallow_dummy(mut self) -> Self {
486 self.disallow_dummy = true;
487 self
488 }
489
490 pub fn typecheck<'a>(
501 &self,
502 expr: &'a MirRelationExpr,
503 ctx: &Context,
504 ) -> Result<Vec<SqlColumnType>, TypeError<'a>> {
505 use MirRelationExpr::*;
506
507 self.checked_recur(|tc| match expr {
508 Constant { typ, rows } => {
509 if let Ok(rows) = rows {
510 for (row, _id) in rows {
511 let datums = row.unpack();
512
513 if datums.len() != typ.column_types.len() {
515 return Err(TypeError::BadConstantRow {
516 source: expr,
517 got: row.clone(),
518 expected: typ.column_types.clone(),
519 });
520 }
521
522 if datums
524 .iter()
525 .zip_eq(typ.column_types.iter())
526 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of_sql(ty))
527 {
528 return Err(TypeError::BadConstantRow {
529 source: expr,
530 got: row.clone(),
531 expected: typ.column_types.clone(),
532 });
533 }
534
535 if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
536 return Err(TypeError::DisallowedDummy {
537 source: expr,
538 });
539 }
540 }
541 }
542
543 Ok(typ.column_types.clone())
544 }
545 Get { typ, id, .. } => {
546 if let Id::Global(_global_id) = id {
547 if !ctx.contains_key(id) {
548 return Ok(typ.column_types.clone());
550 }
551 }
552
553 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
554 source: expr,
555 id: id.clone(),
556 typ: typ.clone(),
557 })?;
558
559 let diffs = relation_subtype_difference(&typ.column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
561
562 if !diffs.is_empty() {
563 return Err(TypeError::MismatchColumns {
564 source: expr,
565 got: typ.column_types.clone(),
566 expected: ctx_typ.clone(),
567 diffs,
568 message: "annotation did not match context type".into(),
569 });
570 }
571
572 Ok(typ.column_types.clone())
573 }
574 Project { input, outputs } => {
575 let t_in = tc.typecheck(input, ctx)?;
576
577 for x in outputs {
578 if *x >= t_in.len() {
579 return Err(TypeError::BadProject {
580 source: expr,
581 got: outputs.clone(),
582 input_type: t_in,
583 });
584 }
585 }
586
587 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
588 }
589 Map { input, scalars } => {
590 let mut t_in = tc.typecheck(input, ctx)?;
591
592 for scalar_expr in scalars.iter() {
593 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
594
595 if self.disallow_dummy && scalar_expr.contains_dummy() {
596 return Err(TypeError::DisallowedDummy {
597 source: expr,
598 });
599 }
600 }
601
602 Ok(t_in)
603 }
604 FlatMap { input, func, exprs } => {
605 let mut t_in = tc.typecheck(input, ctx)?;
606
607 let mut t_exprs = Vec::with_capacity(exprs.len());
608 for scalar_expr in exprs {
609 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
610
611 if self.disallow_dummy && scalar_expr.contains_dummy() {
612 return Err(TypeError::DisallowedDummy {
613 source: expr,
614 });
615 }
616 }
617 let t_out = func.output_type().column_types;
620
621 t_in.extend(t_out);
623 Ok(t_in)
624 }
625 Filter { input, predicates } => {
626 let mut t_in = tc.typecheck(input, ctx)?;
627
628 for column in non_nullable_columns(predicates) {
631 t_in[column].nullable = false;
632 }
633
634 for scalar_expr in predicates {
635 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
636
637 if t.scalar_type != SqlScalarType::Bool {
641 let sub = t.scalar_type.clone();
642
643 return Err(TypeError::MismatchColumn {
644 source: expr,
645 got: t,
646 expected: SqlColumnType {
647 scalar_type: SqlScalarType::Bool,
648 nullable: true,
649 },
650 diffs: vec![SqlColumnTypeDifference::NotSubtype { sub, sup: SqlScalarType::Bool }],
651 message: "expected boolean condition".into(),
652 });
653 }
654
655 if self.disallow_dummy && scalar_expr.contains_dummy() {
656 return Err(TypeError::DisallowedDummy {
657 source: expr,
658 });
659 }
660 }
661
662 Ok(t_in)
663 }
664 Join {
665 inputs,
666 equivalences,
667 implementation,
668 } => {
669 let mut t_in_global = Vec::new();
670 let mut t_in_local = vec![Vec::new(); inputs.len()];
671
672 for (i, input) in inputs.iter().enumerate() {
673 let input_t = tc.typecheck(input, ctx)?;
674 t_in_global.extend(input_t.clone());
675 t_in_local[i] = input_t;
676 }
677
678 for eq_class in equivalences {
679 let mut t_exprs: Vec<SqlColumnType> = Vec::with_capacity(eq_class.len());
680
681 let mut all_nullable = true;
682
683 for scalar_expr in eq_class {
684 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
686
687 if !t_expr.nullable {
688 all_nullable = false;
689 }
690
691 if let Some(t_first) = t_exprs.get(0) {
692 let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
693 if !diffs.is_empty() {
694 return Err(TypeError::MismatchColumn {
695 source: expr,
696 got: t_expr,
697 expected: t_first.clone(),
698 diffs,
699 message: "equivalence class members have different scalar types".into(),
700 });
701 }
702
703 if self.strict_join_equivalences {
707 if t_expr.nullable != t_first.nullable {
708 let sub = t_expr.clone();
709 let sup = t_first.clone();
710
711 let err = TypeError::MismatchColumn {
712 source: expr,
713 got: t_expr.clone(),
714 expected: t_first.clone(),
715 diffs: vec![SqlColumnTypeDifference::Nullability { sub, sup }],
716 message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
717 };
718
719 ::tracing::debug!("{err}");
721 }
722 }
723 }
724
725 if self.disallow_dummy && scalar_expr.contains_dummy() {
726 return Err(TypeError::DisallowedDummy {
727 source: expr,
728 });
729 }
730
731 t_exprs.push(t_expr);
732 }
733
734 if self.strict_join_equivalences && all_nullable {
735 let err = TypeError::BadJoinEquivalence {
736 source: expr,
737 got: t_exprs,
738 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
739 };
740
741 ::tracing::debug!("{err}");
743 }
744 }
745
746 match implementation {
748 JoinImplementation::Differential((start_idx, first_key, _), others) => {
749 if let Some(key) = first_key {
750 for k in key {
751 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
752 }
753 }
754
755 for (idx, key, _) in others {
756 for k in key {
757 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
758 }
759 }
760 }
761 JoinImplementation::DeltaQuery(plans) => {
762 for plan in plans {
763 for (idx, key, _) in plan {
764 for k in key {
765 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
766 }
767 }
768 }
769 }
770 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
771 let typ: Vec<SqlColumnType> = key
772 .iter()
773 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
774 .collect::<Result<Vec<SqlColumnType>, TypeError>>()?;
775
776 for row in consts {
777 let datums = row.unpack();
778
779 if datums.len() != typ.len() {
781 return Err(TypeError::BadConstantRow {
782 source: expr,
783 got: row.clone(),
784 expected: typ,
785 });
786 }
787
788 if datums
790 .iter()
791 .zip_eq(typ.iter())
792 .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of_sql(ty))
793 {
794 return Err(TypeError::BadConstantRow {
795 source: expr,
796 got: row.clone(),
797 expected: typ,
798 });
799 }
800 }
801 }
802 JoinImplementation::Unimplemented => (),
803 }
804
805 Ok(t_in_global)
806 }
807 Reduce {
808 input,
809 group_key,
810 aggregates,
811 monotonic: _,
812 expected_group_size: _,
813 } => {
814 let t_in = tc.typecheck(input, ctx)?;
815
816 let mut t_out = group_key
817 .iter()
818 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
819 .collect::<Result<Vec<_>, _>>()?;
820
821 if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
822 return Err(TypeError::DisallowedDummy {
823 source: expr,
824 });
825 }
826
827 for agg in aggregates {
828 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
829 }
830
831 Ok(t_out)
832 }
833 TopK {
834 input,
835 group_key,
836 order_key,
837 limit: _,
838 offset: _,
839 monotonic: _,
840 expected_group_size: _,
841 } => {
842 let t_in = tc.typecheck(input, ctx)?;
843
844 for &k in group_key {
845 if k >= t_in.len() {
846 return Err(TypeError::BadTopKGroupKey {
847 source: expr,
848 k,
849 input_type: t_in,
850 });
851 }
852 }
853
854 for order in order_key {
855 if order.column >= t_in.len() {
856 return Err(TypeError::BadTopKOrdering {
857 source: expr,
858 order: order.clone(),
859 input_type: t_in,
860 });
861 }
862 }
863
864 Ok(t_in)
865 }
866 Negate { input } => tc.typecheck(input, ctx),
867 Threshold { input } => tc.typecheck(input, ctx),
868 Union { base, inputs } => {
869 let mut t_base = tc.typecheck(base, ctx)?;
870
871 for input in inputs {
872 let t_input = tc.typecheck(input, ctx)?;
873
874 let len_sub = t_base.len();
875 let len_sup = t_input.len();
876 if len_sub != len_sup {
877 return Err(TypeError::MismatchColumns {
878 source: expr,
879 got: t_base.clone(),
880 expected: t_input,
881 diffs: vec![SqlRelationTypeDifference::Length {
882 len_sub,
883 len_sup,
884 }],
885 message: "union branches have different numbers of columns".into(),
886 });
887 }
888
889 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
890 *base_col =
891 base_col
892 .union(&input_col)
893 .map_err(|e| {
894 let base_col = base_col.clone();
895 let diffs = column_subtype_difference(&base_col, &input_col);
896
897 TypeError::MismatchColumn {
898 source: expr,
899 got: input_col,
900 expected: base_col,
901 diffs,
902 message: format!(
903 "couldn't compute union of column types in union: {e}"
904 ),
905 }
906 })?;
907 }
908 }
909
910 Ok(t_base)
911 }
912 Let { id, value, body } => {
913 let t_value = tc.typecheck(value, ctx)?;
914
915 let binding = Id::Local(*id);
916 if ctx.contains_key(&binding) {
917 return Err(TypeError::Shadowing {
918 source: expr,
919 id: binding,
920 });
921 }
922
923 let mut body_ctx = ctx.clone();
924 body_ctx.insert(Id::Local(*id), t_value);
925
926 tc.typecheck(body, &body_ctx)
927 }
928 LetRec { ids, values, body, limits: _ } => {
929 if ids.len() != values.len() {
930 return Err(TypeError::BadLetRecBindings { source: expr });
931 }
932
933 let mut ctx = ctx.clone();
936 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
938 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
939 }
940
941 for (id, value) in ids.iter().zip_eq(values.iter()) {
942 let typ = tc.typecheck(value, &ctx)?;
943
944 let id = Id::Local(id.clone());
945 if let Some(ctx_typ) = ctx.get_mut(&id) {
946 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
947 *base_col = base_col.union(&input_col).map_err(|e| {
948 let base_col = base_col.clone();
949 let diffs = column_subtype_difference(&base_col, &input_col);
950
951 TypeError::MismatchColumn {
952 source: expr,
953 got: input_col,
954 expected: base_col,
955 diffs,
956 message: format!(
957 "couldn't compute union of column types in let rec: {e}"
958 ),
959 }
960 })?;
961 }
962 } else {
963 ctx.insert(id, typ);
965 }
966 }
967
968 tc.typecheck(body, &ctx)
969 }
970 ArrangeBy { input, keys } => {
971 let t_in = tc.typecheck(input, ctx)?;
972
973 for key in keys {
974 for k in key {
975 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
976 }
977 }
978
979 Ok(t_in)
980 }
981 })
982 }
983
984 fn collect_recursive_variable_types<'a>(
988 &self,
989 expr: &'a MirRelationExpr,
990 ids: &[LocalId],
991 ctx: &mut Context,
992 ) -> Result<(), TypeError<'a>> {
993 use MirRelationExpr::*;
994
995 self.checked_recur(|tc| {
996 match expr {
997 Get {
998 id: Id::Local(id),
999 typ,
1000 ..
1001 } => {
1002 if !ids.contains(id) {
1003 return Ok(());
1004 }
1005
1006 let id = Id::Local(id.clone());
1007 if let Some(ctx_typ) = ctx.get_mut(&id) {
1008 for (base_col, input_col) in
1009 ctx_typ.iter_mut().zip_eq(typ.column_types.iter())
1010 {
1011 *base_col = base_col.union(input_col).map_err(|e| {
1012 let base_col = base_col.clone();
1013 let diffs = column_subtype_difference(&base_col, input_col);
1014
1015 TypeError::MismatchColumn {
1016 source: expr,
1017 got: input_col.clone(),
1018 expected: base_col,
1019 diffs,
1020 message: format!(
1021 "couldn't compute union of collected column types: {}",
1022 e
1023 ),
1024 }
1025 })?;
1026 }
1027 } else {
1028 ctx.insert(id, typ.column_types.clone());
1029 }
1030 }
1031 Get {
1032 id: Id::Global(..), ..
1033 }
1034 | Constant { .. } => (),
1035 Let { id, value, body } => {
1036 tc.collect_recursive_variable_types(value, ids, ctx)?;
1037
1038 if ids.contains(id) {
1040 return Err(TypeError::Shadowing {
1041 source: expr,
1042 id: Id::Local(*id),
1043 });
1044 }
1045
1046 tc.collect_recursive_variable_types(body, ids, ctx)?;
1047 }
1048 LetRec {
1049 ids: inner_ids,
1050 values,
1051 body,
1052 limits: _,
1053 } => {
1054 for inner_id in inner_ids {
1055 if ids.contains(inner_id) {
1056 return Err(TypeError::Shadowing {
1057 source: expr,
1058 id: Id::Local(*inner_id),
1059 });
1060 }
1061 }
1062
1063 for value in values {
1064 tc.collect_recursive_variable_types(value, ids, ctx)?;
1065 }
1066
1067 tc.collect_recursive_variable_types(body, ids, ctx)?;
1068 }
1069 Project { input, .. }
1070 | Map { input, .. }
1071 | FlatMap { input, .. }
1072 | Filter { input, .. }
1073 | Reduce { input, .. }
1074 | TopK { input, .. }
1075 | Negate { input }
1076 | Threshold { input }
1077 | ArrangeBy { input, .. } => {
1078 tc.collect_recursive_variable_types(input, ids, ctx)?;
1079 }
1080 Join { inputs, .. } => {
1081 for input in inputs {
1082 tc.collect_recursive_variable_types(input, ids, ctx)?;
1083 }
1084 }
1085 Union { base, inputs } => {
1086 tc.collect_recursive_variable_types(base, ids, ctx)?;
1087
1088 for input in inputs {
1089 tc.collect_recursive_variable_types(input, ids, ctx)?;
1090 }
1091 }
1092 }
1093
1094 Ok(())
1095 })
1096 }
1097
1098 fn typecheck_scalar<'a>(
1099 &self,
1100 expr: &'a MirScalarExpr,
1101 source: &'a MirRelationExpr,
1102 column_types: &[SqlColumnType],
1103 ) -> Result<SqlColumnType, TypeError<'a>> {
1104 use MirScalarExpr::*;
1105
1106 self.checked_recur(|tc| match expr {
1107 Column(i, _) => match column_types.get(*i) {
1108 Some(ty) => Ok(ty.clone()),
1109 None => Err(TypeError::NoSuchColumn {
1110 source,
1111 expr,
1112 col: *i,
1113 }),
1114 },
1115 Literal(row, typ) => {
1116 if let Ok(row) = row {
1117 let datums = row.unpack();
1118
1119 if datums.len() != 1
1120 || (datums[0] != mz_repr::Datum::Dummy
1121 && !datums[0].is_instance_of_sql(typ))
1122 {
1123 return Err(TypeError::BadConstantRow {
1124 source,
1125 got: row.clone(),
1126 expected: vec![typ.clone()],
1127 });
1128 }
1129 }
1130
1131 Ok(typ.clone())
1132 }
1133 CallUnmaterializable(func) => Ok(func.output_type()),
1134 CallUnary { expr, func } => {
1135 Ok(func.output_type(tc.typecheck_scalar(expr, source, column_types)?))
1136 }
1137 CallBinary { expr1, expr2, func } => Ok(func.output_type(
1138 tc.typecheck_scalar(expr1, source, column_types)?,
1139 tc.typecheck_scalar(expr2, source, column_types)?,
1140 )),
1141 CallVariadic { exprs, func } => Ok(func.output_type(
1142 exprs
1143 .iter()
1144 .map(|e| tc.typecheck_scalar(e, source, column_types))
1145 .collect::<Result<Vec<_>, TypeError>>()?,
1146 )),
1147 If { cond, then, els } => {
1148 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1149
1150 if cond_type.scalar_type != SqlScalarType::Bool {
1154 let sub = cond_type.scalar_type.clone();
1155
1156 return Err(TypeError::MismatchColumn {
1157 source,
1158 got: cond_type,
1159 expected: SqlColumnType {
1160 scalar_type: SqlScalarType::Bool,
1161 nullable: true,
1162 },
1163 diffs: vec![SqlColumnTypeDifference::NotSubtype {
1164 sub,
1165 sup: SqlScalarType::Bool,
1166 }],
1167 message: "expected boolean condition".into(),
1168 });
1169 }
1170
1171 let then_type = tc.typecheck_scalar(then, source, column_types)?;
1172 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1173 then_type.union(&else_type).map_err(|e| {
1174 let diffs = column_subtype_difference(&then_type, &else_type);
1175
1176 TypeError::MismatchColumn {
1177 source,
1178 got: then_type,
1179 expected: else_type,
1180 diffs,
1181 message: format!("couldn't compute union of column types for if: {e}"),
1182 }
1183 })
1184 }
1185 })
1186 }
1187
1188 pub fn typecheck_aggregate<'a>(
1190 &self,
1191 expr: &'a AggregateExpr,
1192 source: &'a MirRelationExpr,
1193 column_types: &[SqlColumnType],
1194 ) -> Result<SqlColumnType, TypeError<'a>> {
1195 self.checked_recur(|tc| {
1196 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1197
1198 Ok(expr.func.output_type(t_in))
1201 })
1202 }
1203}
1204
/// Reports a typechecking problem.
///
/// The first argument is a boolean severity: when it is true the remaining (format-style)
/// arguments are passed to `soft_panic_or_log!`, otherwise to `tracing::debug!`.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        if $severity {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1217
1218impl crate::Transform for Typecheck {
1219 fn name(&self) -> &'static str {
1220 "Typecheck"
1221 }
1222
1223 fn actually_perform_transform(
1224 &self,
1225 relation: &mut MirRelationExpr,
1226 transform_ctx: &mut crate::TransformCtx,
1227 ) -> Result<(), crate::TransformError> {
1228 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1229
1230 let expected = transform_ctx
1231 .global_id
1232 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1233
1234 if let Some(id) = transform_ctx.global_id {
1235 if self.disallow_new_globals
1236 && expected.is_none()
1237 && transform_ctx.global_id.is_some()
1238 && !id.is_transient()
1239 {
1240 type_error!(
1241 false, "type warning: new non-transient global id {id}\n{}",
1243 relation.pretty()
1244 );
1245 }
1246 }
1247
1248 let got = self.typecheck(relation, &typecheck_ctx);
1249
1250 let humanizer = mz_repr::explain::DummyHumanizer;
1251
1252 match (got, expected) {
1253 (Ok(got), Some(expected)) => {
1254 let id = transform_ctx.global_id.unwrap();
1255
1256 let diffs = relation_subtype_difference(expected, &got);
1258 if !diffs.is_empty() {
1259 let severity = diffs
1261 .iter()
1262 .any(|diff| diff.clone().ignore_nullability().is_some());
1263
1264 let err = TypeError::MismatchColumns {
1265 source: relation,
1266 got,
1267 expected: expected.clone(),
1268 diffs,
1269 message: format!(
1270 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1271 ),
1272 };
1273
1274 type_error!(severity, "type error in known global id {id}:\n{err}");
1275 }
1276 }
1277 (Ok(got), None) => {
1278 if let Some(id) = transform_ctx.global_id {
1279 typecheck_ctx.insert(Id::Global(id), got);
1280 }
1281 }
1282 (Err(err), _) => {
1283 let (expected, binding) = match expected {
1284 Some(expected) => {
1285 let id = transform_ctx.global_id.unwrap();
1286 (
1287 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1288 format!("known global id {id}"),
1289 )
1290 }
1291 None => ("".to_string(), "transient query".to_string()),
1292 };
1293
1294 type_error!(
1295 true, "type error in {binding}:\n{err}\n{expected}{}",
1297 relation.pretty()
1298 );
1299 }
1300 }
1301
1302 Ok(())
1303 }
1304}
1305
1306pub fn columns_pretty<H>(cols: &[SqlColumnType], humanizer: &H) -> String
1308where
1309 H: ExprHumanizer,
1310{
1311 let mut s = String::with_capacity(2 + 3 * cols.len());
1312
1313 s.push('(');
1314
1315 let mut it = cols.iter().peekable();
1316 while let Some(col) = it.next() {
1317 s.push_str(&humanizer.humanize_column_type(col, false));
1318
1319 if it.peek().is_some() {
1320 s.push_str(", ");
1321 }
1322 }
1323
1324 s.push(')');
1325
1326 s
1327}
1328
1329impl SqlRelationTypeDifference {
1330 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1334 where
1335 H: ExprHumanizer,
1336 {
1337 use SqlRelationTypeDifference::*;
1338 match self {
1339 Length { len_sub, len_sup } => {
1340 writeln!(
1341 f,
1342 " number of columns do not match ({len_sub} != {len_sup})"
1343 )
1344 }
1345 Column { col, diff } => {
1346 writeln!(f, " column {col} differs:")?;
1347 diff.humanize(4, h, f)
1348 }
1349 }
1350 }
1351}
1352
1353impl SqlColumnTypeDifference {
1354 pub fn humanize<H>(
1356 &self,
1357 indent: usize,
1358 h: &H,
1359 f: &mut std::fmt::Formatter<'_>,
1360 ) -> std::fmt::Result
1361 where
1362 H: ExprHumanizer,
1363 {
1364 use SqlColumnTypeDifference::*;
1365
1366 write!(f, "{:indent$}", "")?;
1368
1369 match self {
1370 NotSubtype { sub, sup } => {
1371 let sub = h.humanize_scalar_type(sub, false);
1372 let sup = h.humanize_scalar_type(sup, false);
1373
1374 writeln!(f, "{sub} is a not a subtype of {sup}")
1375 }
1376 Nullability { sub, sup } => {
1377 let sub = h.humanize_column_type(sub, false);
1378 let sup = h.humanize_column_type(sup, false);
1379
1380 writeln!(f, "{sub} is nullable but {sup} is not")
1381 }
1382 ElementType { ctor, element_type } => {
1383 writeln!(f, "{ctor} element types differ:")?;
1384
1385 element_type.humanize(indent + 2, h, f)
1386 }
1387 RecordMissingFields { missing } => {
1388 write!(f, "missing column fields:")?;
1389 for col in missing {
1390 write!(f, " {col}")?;
1391 }
1392 f.write_char('\n')
1393 }
1394 RecordFields { fields } => {
1395 writeln!(f, "{} record fields differ:", fields.len())?;
1396
1397 for (col, diff) in fields {
1398 writeln!(f, "{:indent$} field '{col}':", "")?;
1399 diff.humanize(indent + 4, h, f)?;
1400 }
1401 Ok(())
1402 }
1403 }
1404 }
1405}
1406
/// Pairs a [`TypeError`] with an [`ExprHumanizer`] so that the error can be displayed with
/// types rendered by that humanizer.
#[allow(missing_debug_implementations)]
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to display.
    err: &'a TypeError<'a>,
    /// The humanizer used to render types in the error.
    humanizer: &'b H,
}
1416
1417impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1418where
1419 H: ExprHumanizer,
1420{
1421 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1423 Self { err, humanizer }
1424 }
1425}
1426
1427impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1428where
1429 H: ExprHumanizer,
1430{
1431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1432 self.err.humanize(self.humanizer, f)
1433 }
1434}
1435
1436impl<'a> std::fmt::Display for TypeError<'a> {
1437 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1438 TypeErrorHumanizer {
1439 err: self,
1440 humanizer: &DummyHumanizer,
1441 }
1442 .fmt(f)
1443 }
1444}
1445
1446impl<'a> TypeError<'a> {
1447 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1449 use TypeError::*;
1450 match self {
1451 Unbound { source, .. }
1452 | NoSuchColumn { source, .. }
1453 | MismatchColumn { source, .. }
1454 | MismatchColumns { source, .. }
1455 | BadConstantRow { source, .. }
1456 | BadProject { source, .. }
1457 | BadJoinEquivalence { source, .. }
1458 | BadTopKGroupKey { source, .. }
1459 | BadTopKOrdering { source, .. }
1460 | BadLetRecBindings { source }
1461 | Shadowing { source, .. }
1462 | DisallowedDummy { source, .. } => Some(source),
1463 Recursion { .. } => None,
1464 }
1465 }
1466
1467 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1468 where
1469 H: ExprHumanizer,
1470 {
1471 if let Some(source) = self.source() {
1472 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1473 }
1474
1475 use TypeError::*;
1476 match self {
1477 Unbound { source: _, id, typ } => {
1478 let typ = columns_pretty(&typ.column_types, humanizer);
1479 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1480 }
1481 NoSuchColumn {
1482 source: _,
1483 expr,
1484 col,
1485 } => writeln!(f, "{expr} references non-existent column {col}")?,
1486 MismatchColumn {
1487 source: _,
1488 got,
1489 expected,
1490 diffs,
1491 message,
1492 } => {
1493 let got = humanizer.humanize_column_type(got, false);
1494 let expected = humanizer.humanize_column_type(expected, false);
1495 writeln!(
1496 f,
1497 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1498 )?;
1499
1500 for diff in diffs {
1501 diff.humanize(2, humanizer, f)?;
1502 }
1503 }
1504 MismatchColumns {
1505 source: _,
1506 got,
1507 expected,
1508 diffs,
1509 message,
1510 } => {
1511 let got = columns_pretty(got, humanizer);
1512 let expected = columns_pretty(expected, humanizer);
1513
1514 writeln!(
1515 f,
1516 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1517 )?;
1518
1519 for diff in diffs {
1520 diff.humanize(humanizer, f)?;
1521 }
1522 }
1523 BadConstantRow {
1524 source: _,
1525 got,
1526 expected,
1527 } => {
1528 let expected = columns_pretty(expected, humanizer);
1529
1530 writeln!(
1531 f,
1532 "bad constant row\n got {got}\nexpected row of type {expected}"
1533 )?
1534 }
1535 BadProject {
1536 source: _,
1537 got,
1538 input_type,
1539 } => {
1540 let input_type = columns_pretty(input_type, humanizer);
1541
1542 writeln!(
1543 f,
1544 "projection of non-existant columns {got:?} from type {input_type}"
1545 )?
1546 }
1547 BadJoinEquivalence {
1548 source: _,
1549 got,
1550 message,
1551 } => {
1552 let got = columns_pretty(got, humanizer);
1553
1554 writeln!(f, "bad join equivalence {got}: {message}")?
1555 }
1556 BadTopKGroupKey {
1557 source: _,
1558 k,
1559 input_type,
1560 } => {
1561 let input_type = columns_pretty(input_type, humanizer);
1562
1563 writeln!(
1564 f,
1565 "TopK group key component references invalid column {k} in columns: {input_type}"
1566 )?
1567 }
1568 BadTopKOrdering {
1569 source: _,
1570 order,
1571 input_type,
1572 } => {
1573 let col = order.column;
1574 let num_cols = input_type.len();
1575 let are = if num_cols == 1 { "is" } else { "are" };
1576 let s = if num_cols == 1 { "" } else { "s" };
1577 let input_type = columns_pretty(input_type, humanizer);
1578
1579 let mode = HumanizedExplain::new(false);
1581 let order = mode.expr(order, None);
1582
1583 writeln!(
1584 f,
1585 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
1586 )?
1587 }
1588 BadLetRecBindings { source: _ } => {
1589 writeln!(f, "LetRec ids and definitions don't line up")?
1590 }
1591 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
1592 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
1593 Recursion { error } => writeln!(f, "{error}")?,
1594 }
1595
1596 Ok(())
1597 }
1598}