1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27 ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28};
29
/// A typechecking [`Context`] behind an `Arc<Mutex<..>>`, so that one context
/// can be held by several owners at once.
pub type SharedTypecheckingContext = Arc<Mutex<Context>>;
35
36pub fn empty_typechecking_context() -> SharedTypecheckingContext {
38 Arc::new(Mutex::new(BTreeMap::new()))
39}
40
/// The ways in which typechecking a `MirRelationExpr` can fail.
///
/// Most variants carry a `source` field borrowing the relation expression in
/// which the problem was detected, which is why the type is parameterized by
/// the lifetime `'a` of the expression being checked.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referred to an identifier that has no binding in the context.
    Unbound {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The identifier that could not be found.
        id: Id,
        /// The type annotation carried by the offending `Get`.
        typ: ReprRelationType,
    },
    /// A scalar expression referenced a column index that does not exist.
    NoSuchColumn {
        /// The relation expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The scalar expression containing the bad column reference.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column type did not match what was expected.
    MismatchColumn {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: ReprColumnType,
        /// The column type that was expected.
        expected: ReprColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprColumnTypeDifference>,
        /// A human-readable description of the check that failed.
        message: String,
    },
    /// A whole relation type (a list of column types) did not match what was
    /// expected.
    MismatchColumns {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ReprColumnType>,
        /// The column types that were expected.
        expected: Vec<ReprColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprRelationTypeDifference>,
        /// A human-readable description of the check that failed.
        message: String,
    },
    /// A constant row has a different number of datums than there are
    /// expected column types.
    BadConstantRowLen {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The number of datums in the row.
        got: usize,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A constant row has the right length, but some datums do not fit their
    /// column types.
    BadConstantRow {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// Pairs of (column index, difference) for every datum that did not
        /// fit its column type.
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A projection refers to a column index beyond the end of its input.
    BadProject {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The full list of projected column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class was judged ill-formed.
    BadJoinEquivalence {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The column types of the members of the equivalence class.
        got: Vec<ReprColumnType>,
        /// A human-readable description of the check that failed.
        message: String,
    },
    /// A `TopK` group key refers to a column index beyond the end of its input.
    BadTopKGroupKey {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The out-of-range group key column.
        k: usize,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering refers to a column index beyond the end of its input.
    BadTopKOrdering {
        /// The expression in which the error was detected.
        source: &'a MirRelationExpr,
        /// The ordering whose column is out of range.
        order: ColumnOrder,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` has a different number of identifiers than values.
    BadLetRecBindings {
        /// The offending `LetRec` expression.
        source: &'a MirRelationExpr,
    },
    /// A binding reuses an identifier that is already bound.
    Shadowing {
        /// The expression introducing the shadowing binding.
        source: &'a MirRelationExpr,
        /// The identifier that is bound more than once.
        id: Id,
    },
    /// The recursion limit was exceeded while typechecking.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A dummy value was found while dummies were disallowed.
    DisallowedDummy {
        /// The expression containing the dummy value.
        source: &'a MirRelationExpr,
    },
}
170
171impl<'a> From<RecursionLimitError> for TypeError<'a> {
172 fn from(error: RecursionLimitError) -> Self {
173 TypeError::Recursion { error }
174 }
175}
176
/// The typechecking environment: maps each bound [`Id`] to the column types
/// recorded for it.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
178
/// One way in which a relation type (`sub`) fails to be a subtype of another
/// (`sup`), as reported by `relation_subtype_difference`.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relation types have different numbers of columns.
    Length {
        /// Number of columns in `sub`.
        len_sub: usize,
        /// Number of columns in `sup`.
        len_sup: usize,
    },
    /// The two relation types differ at a particular column.
    Column {
        /// Index of the differing column.
        col: usize,
        /// How the column types differ.
        diff: ReprColumnTypeDifference,
    },
}
199
/// One way in which a column type (`sub`) fails to be a subtype of another
/// (`sup`).
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar types are not related by subtyping.
    NotSubtype {
        /// The would-be subtype.
        sub: ReprScalarType,
        /// The would-be supertype.
        sup: ReprScalarType,
    },
    /// The column types disagree on nullability.
    Nullability {
        /// The would-be subtype.
        sub: ReprColumnType,
        /// The would-be supertype.
        sup: ReprColumnType,
    },
    /// The difference lies in the element type of a container type.
    ElementType {
        /// Name of the container's type constructor.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// Record fields are missing.
    // NOTE(review): not constructed anywhere in the visible part of this
    // file; presumably these are fields present in `sup` but absent from
    // `sub` — confirm.
    RecordMissingFields {
        /// The names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Record fields differ.
    RecordFields {
        /// The differences between the fields.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
238
239impl ReprRelationTypeDifference {
240 pub fn ignore_nullability(self) -> Option<Self> {
244 use ReprRelationTypeDifference::*;
245
246 match self {
247 Length { .. } => Some(self),
248 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
249 }
250 }
251}
252
253impl ReprColumnTypeDifference {
254 pub fn ignore_nullability(self) -> Option<Self> {
258 use ReprColumnTypeDifference::*;
259
260 match self {
261 Nullability { .. } => None,
262 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
263 ElementType { ctor, element_type } => {
264 element_type
265 .ignore_nullability()
266 .map(|element_type| ElementType {
267 ctor,
268 element_type: Box::new(element_type),
269 })
270 }
271 RecordFields { fields } => {
272 let fields = fields
273 .into_iter()
274 .flat_map(|diff| diff.ignore_nullability())
275 .collect::<Vec<_>>();
276
277 if fields.is_empty() {
278 None
279 } else {
280 Some(RecordFields { fields })
281 }
282 }
283 }
284 }
285}
286
287pub fn relation_subtype_difference(
291 sub: &[ReprColumnType],
292 sup: &[ReprColumnType],
293) -> Vec<ReprRelationTypeDifference> {
294 let mut diffs = Vec::new();
295
296 if sub.len() != sup.len() {
297 diffs.push(ReprRelationTypeDifference::Length {
298 len_sub: sub.len(),
299 len_sup: sup.len(),
300 });
301
302 return diffs;
304 }
305
306 diffs.extend(
307 sub.iter()
308 .zip_eq(sup.iter())
309 .enumerate()
310 .flat_map(|(col, (sub_ty, sup_ty))| {
311 column_subtype_difference(sub_ty, sup_ty)
312 .into_iter()
313 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
314 }),
315 );
316
317 diffs
318}
319
320pub fn column_subtype_difference(
324 sub: &ReprColumnType,
325 sup: &ReprColumnType,
326) -> Vec<ReprColumnTypeDifference> {
327 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
328
329 if sub.nullable && !sup.nullable {
330 diffs.push(ReprColumnTypeDifference::Nullability {
331 sub: sub.clone(),
332 sup: sup.clone(),
333 });
334 }
335
336 diffs
337}
338
339pub fn scalar_subtype_difference(
343 sub: &ReprScalarType,
344 sup: &ReprScalarType,
345) -> Vec<ReprColumnTypeDifference> {
346 use ReprScalarType::*;
347
348 let mut diffs = Vec::new();
349
350 match (sub, sup) {
351 (
352 List {
353 element_type: sub_elt,
354 ..
355 },
356 List {
357 element_type: sup_elt,
358 ..
359 },
360 )
361 | (
362 Map {
363 value_type: sub_elt,
364 ..
365 },
366 Map {
367 value_type: sup_elt,
368 ..
369 },
370 )
371 | (
372 Range {
373 element_type: sub_elt,
374 ..
375 },
376 Range {
377 element_type: sup_elt,
378 ..
379 },
380 )
381 | (Array(sub_elt), Array(sup_elt)) => {
382 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
383 diffs.extend(
384 scalar_subtype_difference(sub_elt, sup_elt)
385 .into_iter()
386 .map(|diff| ReprColumnTypeDifference::ElementType {
387 ctor: ctor.clone(),
388 element_type: Box::new(diff),
389 }),
390 );
391 }
392 (
393 Record {
394 fields: sub_fields, ..
395 },
396 Record {
397 fields: sup_fields, ..
398 },
399 ) => {
400 if sub_fields.len() != sup_fields.len() {
401 diffs.push(ReprColumnTypeDifference::NotSubtype {
402 sub: sub.clone(),
403 sup: sup.clone(),
404 });
405 return diffs;
406 }
407
408 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
409 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
410 }
411 }
412 (_, _) => {
413 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
414 diffs.push(ReprColumnTypeDifference::NotSubtype {
415 sub: sub.clone(),
416 sup: sup.clone(),
417 })
418 }
419 }
420 };
421
422 diffs
423}
424
425pub fn scalar_union(
429 typ: &mut ReprScalarType,
430 other: &ReprScalarType,
431) -> Vec<ReprColumnTypeDifference> {
432 use ReprScalarType::*;
433
434 let mut diffs = Vec::new();
435
436 let ctor = ReprScalarBaseType::from(&*typ);
438 match (typ, other) {
439 (
440 List {
441 element_type: typ_elt,
442 },
443 List {
444 element_type: other_elt,
445 },
446 )
447 | (
448 Map {
449 value_type: typ_elt,
450 },
451 Map {
452 value_type: other_elt,
453 },
454 )
455 | (
456 Range {
457 element_type: typ_elt,
458 },
459 Range {
460 element_type: other_elt,
461 },
462 )
463 | (Array(typ_elt), Array(other_elt)) => {
464 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
465 diffs.extend(
466 res.into_iter()
467 .map(|diff| ReprColumnTypeDifference::ElementType {
468 ctor: format!("{ctor:?}"),
469 element_type: Box::new(diff),
470 }),
471 );
472 }
473 (
474 Record { fields: typ_fields },
475 Record {
476 fields: other_fields,
477 },
478 ) => {
479 if typ_fields.len() != other_fields.len() {
480 diffs.push(ReprColumnTypeDifference::NotSubtype {
481 sub: ReprScalarType::Record {
482 fields: typ_fields.clone(),
483 },
484 sup: other.clone(),
485 });
486 return diffs;
487 }
488
489 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
490 diffs.extend(column_union(typ_ty, other_ty));
491 }
492 }
493 (typ, _) => {
494 if ctor != ReprScalarBaseType::from(other) {
495 diffs.push(ReprColumnTypeDifference::NotSubtype {
496 sub: typ.clone(),
497 sup: other.clone(),
498 })
499 }
500 }
501 };
502
503 diffs
504}
505
506pub fn column_union(
510 typ: &mut ReprColumnType,
511 other: &ReprColumnType,
512) -> Vec<ReprColumnTypeDifference> {
513 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
514
515 if diffs.is_empty() {
516 typ.nullable |= other.nullable;
517 }
518
519 diffs
520}
521
522pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
527 if sub.len() != sup.len() {
528 return false;
529 }
530
531 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
532 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
533 })
534}
535
/// One way in which a datum fails to fit a scalar type.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// A null was found where one was not allowed.
    Null {
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// The datum is of a different kind than the expected type.
    Mismatch {
        /// The debug rendering of the offending datum.
        got_debug: String,
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// The datum has the wrong number of dimensions (for an `int2vector`) or
    /// elements (for a record).
    MismatchDimensions {
        /// Name of the type constructor being checked.
        ctor: String,
        /// The count that was found.
        got: usize,
        /// The count that was expected.
        expected: usize,
    },
    /// The difference was found in an element of a container.
    ElementType {
        /// Name of the container's type constructor.
        ctor: String,
        /// The difference found for the element.
        element_type: Box<DatumTypeDifference>,
    },
}
569
/// Checks that `datum` fits `column_type`.
///
/// Returns `Ok(())` if it does, and otherwise the first difference found.
/// A `Datum::Null` is accepted outright when the column is nullable; every
/// other case is delegated to the nested `difference_with_scalar_type`.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    /// Checks that `datum` fits `scalar_type`, recursing into containers.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        /// Builds the `Mismatch` error for a datum of the wrong kind.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        // Jsonb is handled separately, because several kinds of datum are
        // acceptable for it and its containers nest Jsonb values.
        if let ReprScalarType::Jsonb = scalar_type {
            match datum {
                // A dummy datum is accepted for every type.
                Datum::Dummy => Ok(()),
                Datum::Null => Err(DatumTypeDifference::Null {
                    expected: ReprScalarType::Jsonb,
                }),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                Datum::List(list) => {
                    // Every element must itself fit Jsonb.
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    // Every value must itself fit Jsonb; keys are not checked.
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            /// Wraps a nested difference in `ElementType`, labelled with `ctor`.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            // Each datum kind gets an accepting arm for its matching type(s),
            // followed by an arm that rejects every other type. There is no
            // wildcard on the datum side.
            match (datum, scalar_type) {
                // A dummy datum is accepted for every type.
                (Datum::Dummy, _) => Ok(()),
                // Nulls allowed by the column were already accepted by the
                // caller, so a null reaching this point is an error.
                (Datum::Null, _) => Err(DatumTypeDifference::Null {
                    expected: scalar_type.clone(),
                }),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    for e in array.elements().iter() {
                        // Null elements are accepted without further checks.
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    if !array.has_int2vector_dims() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    // Unlike for arrays above, null elements are not skipped
                    // here, so a null element is reported as an error.
                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    for e in list.iter() {
                        // Null elements are accepted without further checks.
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                // A list datum is also accepted for a record type, one
                // element per field.
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        // A null is accepted only for a nullable field. The
                        // error for a non-nullable field is returned as is,
                        // not wrapped in `ElementType`.
                        if let Datum::Null = e {
                            if t.nullable {
                                continue;
                            } else {
                                return Err(DatumTypeDifference::Null {
                                    expected: t.scalar_type.clone(),
                                });
                            }
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    for (_, v) in map.iter() {
                        // Null values are accepted without further checks.
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // `JsonNull` is only accepted by the Jsonb branch above.
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    match inner {
                        // A range without an inner value has nothing to check.
                        None => Ok(()),
                        Some(inner) => {
                            // Each bound that is present must fit the element type.
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // A null in a nullable column is fine, whatever the scalar type.
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
786
787fn row_difference_with_column_types<'a>(
788 source: &'a MirRelationExpr,
789 datums: &Vec<Datum<'_>>,
790 column_types: &[ReprColumnType],
791) -> Result<(), TypeError<'a>> {
792 if datums.len() != column_types.len() {
794 return Err(TypeError::BadConstantRowLen {
795 source,
796 got: datums.len(),
797 expected: column_types.to_vec(),
798 });
799 }
800
801 let mut mismatches = Vec::new();
803 for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
804 if let Err(e) = datum_difference_with_column_type(d, ty) {
805 mismatches.push((i, e));
806 }
807 }
808 if !mismatches.is_empty() {
809 return Err(TypeError::BadConstantRow {
810 source,
811 mismatches,
812 expected: column_types.to_vec(),
813 });
814 }
815
816 Ok(())
817}
/// A typechecker for `MirRelationExpr`s, together with its configuration.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared typechecking context.
    ctx: SharedTypecheckingContext,
    /// Set by [`Typecheck::disallow_new_globals`].
    // NOTE(review): not read anywhere in the visible part of this file;
    // presumably it rejects global ids that are missing from the context —
    // confirm.
    disallow_new_globals: bool,
    /// Whether join equivalence classes get the extra nullability checks.
    strict_join_equivalences: bool,
    /// Whether dummy values are reported as errors.
    disallow_dummy: bool,
    /// The guard that bounds recursion depth while typechecking.
    recursion_guard: RecursionGuard,
}
832
833impl CheckedRecursion for Typecheck {
834 fn recursion_guard(&self) -> &RecursionGuard {
835 &self.recursion_guard
836 }
837}
838
839impl Typecheck {
840 pub fn new(ctx: SharedTypecheckingContext) -> Self {
842 Self {
843 ctx,
844 disallow_new_globals: false,
845 strict_join_equivalences: false,
846 disallow_dummy: false,
847 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
848 }
849 }
850
851 pub fn disallow_new_globals(mut self) -> Self {
855 self.disallow_new_globals = true;
856 self
857 }
858
859 pub fn strict_join_equivalences(mut self) -> Self {
863 self.strict_join_equivalences = true;
864
865 self
866 }
867
868 pub fn disallow_dummy(mut self) -> Self {
870 self.disallow_dummy = true;
871 self
872 }
873
874 pub fn typecheck<'a>(
885 &self,
886 expr: &'a MirRelationExpr,
887 ctx: &Context,
888 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
889 use MirRelationExpr::*;
890
891 self.checked_recur(|tc| match expr {
892 Constant { typ, rows } => {
893 if let Ok(rows) = rows {
894 for (row, _id) in rows {
895 let datums = row.unpack();
896
897 let col_types = typ
898 .column_types
899 .iter()
900 .cloned()
901 .collect_vec();
902 row_difference_with_column_types(
903 expr, &datums, &col_types,
904 )?;
905
906 if self.disallow_dummy
907 && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
908 {
909 return Err(TypeError::DisallowedDummy {
910 source: expr,
911 });
912 }
913 }
914 }
915
916 Ok(typ.column_types.iter().cloned().collect_vec())
917 }
918 Get { typ, id, .. } => {
919 if let Id::Global(_global_id) = id {
920 if !ctx.contains_key(id) {
921 return Ok(typ.column_types.iter().cloned().collect_vec());
923 }
924 }
925
926 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
927 source: expr,
928 id: id.clone(),
929 typ: typ.clone(),
930 })?;
931
932 let column_types = typ.column_types.iter().cloned().collect_vec();
933
934 let diffs = relation_subtype_difference(&column_types, ctx_typ)
936 .into_iter()
937 .flat_map(|diff| diff.ignore_nullability())
938 .collect::<Vec<_>>();
939
940 if !diffs.is_empty() {
941 return Err(TypeError::MismatchColumns {
942 source: expr,
943 got: column_types,
944 expected: ctx_typ.clone(),
945 diffs,
946 message: "annotation did not match context type".to_string(),
947 });
948 }
949
950 Ok(column_types)
951 }
952 Project { input, outputs } => {
953 let t_in = tc.typecheck(input, ctx)?;
954
955 for x in outputs {
956 if *x >= t_in.len() {
957 return Err(TypeError::BadProject {
958 source: expr,
959 got: outputs.clone(),
960 input_type: t_in,
961 });
962 }
963 }
964
965 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
966 }
967 Map { input, scalars } => {
968 let mut t_in = tc.typecheck(input, ctx)?;
969
970 for scalar_expr in scalars.iter() {
971 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
972
973 if self.disallow_dummy && scalar_expr.contains_dummy() {
974 return Err(TypeError::DisallowedDummy {
975 source: expr,
976 });
977 }
978 }
979
980 Ok(t_in)
981 }
982 FlatMap { input, func, exprs } => {
983 let mut t_in = tc.typecheck(input, ctx)?;
984
985 let mut t_exprs = Vec::with_capacity(exprs.len());
986 for scalar_expr in exprs {
987 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
988
989 if self.disallow_dummy && scalar_expr.contains_dummy() {
990 return Err(TypeError::DisallowedDummy {
991 source: expr,
992 });
993 }
994 }
995 let t_out: Vec<ReprColumnType> = func
998 .output_type().column_types;
999
1000 t_in.extend(t_out);
1002 Ok(t_in)
1003 }
1004 Filter { input, predicates } => {
1005 let mut t_in = tc.typecheck(input, ctx)?;
1006
1007 for column in non_nullable_columns(predicates) {
1010 t_in[column].nullable = false;
1011 }
1012
1013 for scalar_expr in predicates {
1014 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1015
1016 if t.scalar_type != ReprScalarType::Bool {
1020 let sub = t.scalar_type.clone();
1021
1022 return Err(TypeError::MismatchColumn {
1023 source: expr,
1024 got: t,
1025 expected: ReprColumnType {
1026 scalar_type: ReprScalarType::Bool,
1027 nullable: true,
1028 },
1029 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1030 sub,
1031 sup: ReprScalarType::Bool,
1032 }],
1033 message: "expected boolean condition".to_string(),
1034 });
1035 }
1036
1037 if self.disallow_dummy && scalar_expr.contains_dummy() {
1038 return Err(TypeError::DisallowedDummy {
1039 source: expr,
1040 });
1041 }
1042 }
1043
1044 Ok(t_in)
1045 }
1046 Join {
1047 inputs,
1048 equivalences,
1049 implementation,
1050 } => {
1051 let mut t_in_global = Vec::new();
1052 let mut t_in_local = vec![Vec::new(); inputs.len()];
1053
1054 for (i, input) in inputs.iter().enumerate() {
1055 let input_t = tc.typecheck(input, ctx)?;
1056 t_in_global.extend(input_t.clone());
1057 t_in_local[i] = input_t;
1058 }
1059
1060 for eq_class in equivalences {
1061 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1062
1063 let mut all_nullable = true;
1064
1065 for scalar_expr in eq_class {
1066 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1068
1069 if !t_expr.nullable {
1070 all_nullable = false;
1071 }
1072
1073 if let Some(t_first) = t_exprs.get(0) {
1074 let diffs = scalar_subtype_difference(
1075 &t_expr.scalar_type,
1076 &t_first.scalar_type,
1077 );
1078 if !diffs.is_empty() {
1079 return Err(TypeError::MismatchColumn {
1080 source: expr,
1081 got: t_expr,
1082 expected: t_first.clone(),
1083 diffs,
1084 message: "equivalence class members \
1085 have different scalar types"
1086 .to_string(),
1087 });
1088 }
1089
1090 if self.strict_join_equivalences {
1094 if t_expr.nullable != t_first.nullable {
1095 let sub = t_expr.clone();
1096 let sup = t_first.clone();
1097
1098 let err = TypeError::MismatchColumn {
1099 source: expr,
1100 got: t_expr.clone(),
1101 expected: t_first.clone(),
1102 diffs: vec![
1103 ReprColumnTypeDifference::Nullability { sub, sup },
1104 ],
1105 message: "equivalence class members have \
1106 different nullability (and join \
1107 equivalence checking is strict)"
1108 .to_string(),
1109 };
1110
1111 ::tracing::debug!("{err}");
1113 }
1114 }
1115 }
1116
1117 if self.disallow_dummy && scalar_expr.contains_dummy() {
1118 return Err(TypeError::DisallowedDummy {
1119 source: expr,
1120 });
1121 }
1122
1123 t_exprs.push(t_expr);
1124 }
1125
1126 if self.strict_join_equivalences && all_nullable {
1127 let err = TypeError::BadJoinEquivalence {
1128 source: expr,
1129 got: t_exprs,
1130 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1131 };
1132
1133 ::tracing::debug!("{err}");
1135 }
1136 }
1137
1138 match implementation {
1140 JoinImplementation::Differential((start_idx, first_key, _), others) => {
1141 if let Some(key) = first_key {
1142 for k in key {
1143 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1144 }
1145 }
1146
1147 for (idx, key, _) in others {
1148 for k in key {
1149 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1150 }
1151 }
1152 }
1153 JoinImplementation::DeltaQuery(plans) => {
1154 for plan in plans {
1155 for (idx, key, _) in plan {
1156 for k in key {
1157 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1158 }
1159 }
1160 }
1161 }
1162 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1163 let typ: Vec<ReprColumnType> = key
1164 .iter()
1165 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1166 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1167
1168 for row in consts {
1169 let datums = row.unpack();
1170
1171 row_difference_with_column_types(expr, &datums, &typ)?;
1172 }
1173 }
1174 JoinImplementation::Unimplemented => (),
1175 }
1176
1177 Ok(t_in_global)
1178 }
1179 Reduce {
1180 input,
1181 group_key,
1182 aggregates,
1183 monotonic: _,
1184 expected_group_size: _,
1185 } => {
1186 let t_in = tc.typecheck(input, ctx)?;
1187
1188 let mut t_out = group_key
1189 .iter()
1190 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1191 .collect::<Result<Vec<_>, _>>()?;
1192
1193 if self.disallow_dummy
1194 && group_key
1195 .iter()
1196 .any(|scalar_expr| scalar_expr.contains_dummy())
1197 {
1198 return Err(TypeError::DisallowedDummy {
1199 source: expr,
1200 });
1201 }
1202
1203 for agg in aggregates {
1204 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1205 }
1206
1207 Ok(t_out)
1208 }
1209 TopK {
1210 input,
1211 group_key,
1212 order_key,
1213 limit: _,
1214 offset: _,
1215 monotonic: _,
1216 expected_group_size: _,
1217 } => {
1218 let t_in = tc.typecheck(input, ctx)?;
1219
1220 for &k in group_key {
1221 if k >= t_in.len() {
1222 return Err(TypeError::BadTopKGroupKey {
1223 source: expr,
1224 k,
1225 input_type: t_in,
1226 });
1227 }
1228 }
1229
1230 for order in order_key {
1231 if order.column >= t_in.len() {
1232 return Err(TypeError::BadTopKOrdering {
1233 source: expr,
1234 order: order.clone(),
1235 input_type: t_in,
1236 });
1237 }
1238 }
1239
1240 Ok(t_in)
1241 }
1242 Negate { input } => tc.typecheck(input, ctx),
1243 Threshold { input } => tc.typecheck(input, ctx),
1244 Union { base, inputs } => {
1245 let mut t_base = tc.typecheck(base, ctx)?;
1246
1247 for input in inputs {
1248 let t_input = tc.typecheck(input, ctx)?;
1249
1250 let len_sub = t_base.len();
1251 let len_sup = t_input.len();
1252 if len_sub != len_sup {
1253 return Err(TypeError::MismatchColumns {
1254 source: expr,
1255 got: t_base.clone(),
1256 expected: t_input,
1257 diffs: vec![ReprRelationTypeDifference::Length {
1258 len_sub,
1259 len_sup,
1260 }],
1261 message: "Union branches have different numbers of columns".to_string(),
1262 });
1263 }
1264
1265 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1266 let diffs = column_union(base_col, &input_col);
1267 if !diffs.is_empty() {
1268 return Err(TypeError::MismatchColumn {
1269 source: expr,
1270 got: input_col,
1271 expected: base_col.clone(),
1272 diffs,
1273 message:
1274 "couldn't compute union of column types in Union"
1275 .to_string(),
1276 });
1277 }
1278
1279 }
1280 }
1281
1282 Ok(t_base)
1283 }
1284 Let { id, value, body } => {
1285 let t_value = tc.typecheck(value, ctx)?;
1286
1287 let binding = Id::Local(*id);
1288 if ctx.contains_key(&binding) {
1289 return Err(TypeError::Shadowing {
1290 source: expr,
1291 id: binding,
1292 });
1293 }
1294
1295 let mut body_ctx = ctx.clone();
1296 body_ctx.insert(Id::Local(*id), t_value);
1297
1298 tc.typecheck(body, &body_ctx)
1299 }
1300 LetRec { ids, values, body, limits: _ } => {
1301 if ids.len() != values.len() {
1302 return Err(TypeError::BadLetRecBindings { source: expr });
1303 }
1304
1305 let mut ctx = ctx.clone();
1308 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1310 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1311 }
1312
1313 for (id, value) in ids.iter().zip_eq(values.iter()) {
1314 let typ = tc.typecheck(value, &ctx)?;
1315
1316 let id = Id::Local(id.clone());
1317 if let Some(ctx_typ) = ctx.get_mut(&id) {
1318 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1319 let diffs = column_union(base_col, &input_col);
1321 if !diffs.is_empty() {
1322 return Err(TypeError::MismatchColumn {
1323 source: expr,
1324 got: input_col,
1325 expected: base_col.clone(),
1326 diffs,
1327 message:
1328 "couldn't compute union of column types in LetRec"
1329 .to_string(),
1330 })
1331 }
1332 }
1333 } else {
1334 ctx.insert(id, typ);
1336 }
1337 }
1338
1339 tc.typecheck(body, &ctx)
1340 }
1341 ArrangeBy { input, keys } => {
1342 let t_in = tc.typecheck(input, ctx)?;
1343
1344 for key in keys {
1345 for k in key {
1346 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1347 }
1348 }
1349
1350 Ok(t_in)
1351 }
1352 })
1353 }
1354
1355 fn collect_recursive_variable_types<'a>(
1359 &self,
1360 expr: &'a MirRelationExpr,
1361 ids: &[LocalId],
1362 ctx: &mut Context,
1363 ) -> Result<(), TypeError<'a>> {
1364 use MirRelationExpr::*;
1365
1366 self.checked_recur(|tc| {
1367 match expr {
1368 Get {
1369 id: Id::Local(id),
1370 typ,
1371 ..
1372 } => {
1373 if !ids.contains(id) {
1374 return Ok(());
1375 }
1376
1377 let id = Id::Local(id.clone());
1378 if let Some(ctx_typ) = ctx.get_mut(&id) {
1379 let typ = typ.column_types.iter().cloned().collect_vec();
1380
1381 if ctx_typ.len() != typ.len() {
1382 let diffs = relation_subtype_difference(&typ, ctx_typ);
1383
1384 return Err(TypeError::MismatchColumns {
1385 source: expr,
1386 got: typ,
1387 expected: ctx_typ.clone(),
1388 diffs,
1389 message: "environment and type annotation did not match"
1390 .to_string(),
1391 });
1392 }
1393
1394 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1395 let diffs = column_union(base_col, &input_col);
1396 if !diffs.is_empty() {
1397 return Err(TypeError::MismatchColumn {
1398 source: expr,
1399 got: input_col,
1400 expected: base_col.clone(),
1401 diffs,
1402 message:
1403 "couldn't compute union of column types in Get and context"
1404 .to_string(),
1405 });
1406 }
1407 }
1408 } else {
1409 ctx.insert(id, typ.column_types.iter().cloned().collect_vec());
1410 }
1411 }
1412 Get {
1413 id: Id::Global(..), ..
1414 }
1415 | Constant { .. } => (),
1416 Let { id, value, body } => {
1417 tc.collect_recursive_variable_types(value, ids, ctx)?;
1418
1419 if ids.contains(id) {
1421 return Err(TypeError::Shadowing {
1422 source: expr,
1423 id: Id::Local(*id),
1424 });
1425 }
1426
1427 tc.collect_recursive_variable_types(body, ids, ctx)?;
1428 }
1429 LetRec {
1430 ids: inner_ids,
1431 values,
1432 body,
1433 limits: _,
1434 } => {
1435 for inner_id in inner_ids {
1436 if ids.contains(inner_id) {
1437 return Err(TypeError::Shadowing {
1438 source: expr,
1439 id: Id::Local(*inner_id),
1440 });
1441 }
1442 }
1443
1444 for value in values {
1445 tc.collect_recursive_variable_types(value, ids, ctx)?;
1446 }
1447
1448 tc.collect_recursive_variable_types(body, ids, ctx)?;
1449 }
1450 Project { input, .. }
1451 | Map { input, .. }
1452 | FlatMap { input, .. }
1453 | Filter { input, .. }
1454 | Reduce { input, .. }
1455 | TopK { input, .. }
1456 | Negate { input }
1457 | Threshold { input }
1458 | ArrangeBy { input, .. } => {
1459 tc.collect_recursive_variable_types(input, ids, ctx)?;
1460 }
1461 Join { inputs, .. } => {
1462 for input in inputs {
1463 tc.collect_recursive_variable_types(input, ids, ctx)?;
1464 }
1465 }
1466 Union { base, inputs } => {
1467 tc.collect_recursive_variable_types(base, ids, ctx)?;
1468
1469 for input in inputs {
1470 tc.collect_recursive_variable_types(input, ids, ctx)?;
1471 }
1472 }
1473 }
1474
1475 Ok(())
1476 })
1477 }
1478
1479 fn typecheck_scalar<'a>(
1480 &self,
1481 expr: &'a MirScalarExpr,
1482 source: &'a MirRelationExpr,
1483 column_types: &[ReprColumnType],
1484 ) -> Result<ReprColumnType, TypeError<'a>> {
1485 use MirScalarExpr::*;
1486
1487 self.checked_recur(|tc| match expr {
1488 Column(i, _) => match column_types.get(*i) {
1489 Some(ty) => Ok(ty.clone()),
1490 None => Err(TypeError::NoSuchColumn {
1491 source,
1492 expr,
1493 col: *i,
1494 }),
1495 },
1496 Literal(row, typ) => {
1497 let typ = typ.clone();
1498 if let Ok(row) = row {
1499 let datums = row.unpack();
1500
1501 row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1502 }
1503
1504 Ok(typ)
1505 }
1506 CallUnmaterializable(func) => Ok(func.output_type()),
1507 CallUnary { expr, func } => {
1508 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1509 let typ_out = func.output_type(typ_in);
1510 Ok(typ_out)
1511 }
1512 CallBinary { expr1, expr2, func } => {
1513 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1514 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1515 let typ_out = func.output_type(&[typ_in1, typ_in2]);
1516 Ok(typ_out)
1517 }
1518 CallVariadic { exprs, func } => Ok(func.output_type(
1519 exprs
1520 .iter()
1521 .map(|e| tc.typecheck_scalar(e, source, column_types))
1522 .collect::<Result<Vec<_>, TypeError>>()?,
1523 )),
1524 If { cond, then, els } => {
1525 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1526
1527 if cond_type.scalar_type != ReprScalarType::Bool {
1531 let sub = cond_type.scalar_type.clone();
1532
1533 return Err(TypeError::MismatchColumn {
1534 source,
1535 got: cond_type,
1536 expected: ReprColumnType {
1537 scalar_type: ReprScalarType::Bool,
1538 nullable: true,
1539 },
1540 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1541 sub,
1542 sup: ReprScalarType::Bool,
1543 }],
1544 message: "expected boolean condition".to_string(),
1545 });
1546 }
1547
1548 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1549 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1550
1551 let diffs = column_union(&mut then_type, &else_type);
1552 if !diffs.is_empty() {
1553 return Err(TypeError::MismatchColumn {
1554 source,
1555 got: then_type,
1556 expected: else_type,
1557 diffs,
1558 message: "couldn't compute union of column types for If".to_string(),
1559 });
1560 }
1561
1562 Ok(then_type)
1563 }
1564 })
1565 }
1566
1567 pub fn typecheck_aggregate<'a>(
1569 &self,
1570 expr: &'a AggregateExpr,
1571 source: &'a MirRelationExpr,
1572 column_types: &[ReprColumnType],
1573 ) -> Result<ReprColumnType, TypeError<'a>> {
1574 self.checked_recur(|tc| {
1575 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1576
1577 Ok(expr.func.output_type(t_in))
1580 })
1581 }
1582}
1583
/// Reports a type checking problem at one of two severities.
///
/// The first argument is a boolean expression. When it is `true` the remaining
/// format arguments are forwarded to `soft_panic_or_log!`; otherwise they are
/// forwarded to `tracing::debug!`.
macro_rules! type_error {
    ($is_severe:expr, $($message:tt)+) => {{
        if $is_severe {
            soft_panic_or_log!($($message)+);
        } else {
            ::tracing::debug!($($message)+);
        }
    }}
}
1596
1597impl crate::Transform for Typecheck {
1598 fn name(&self) -> &'static str {
1599 "Typecheck"
1600 }
1601
1602 fn actually_perform_transform(
1603 &self,
1604 relation: &mut MirRelationExpr,
1605 transform_ctx: &mut crate::TransformCtx,
1606 ) -> Result<(), crate::TransformError> {
1607 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1608
1609 let expected = transform_ctx
1610 .global_id
1611 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1612
1613 if let Some(id) = transform_ctx.global_id {
1614 if self.disallow_new_globals
1615 && expected.is_none()
1616 && transform_ctx.global_id.is_some()
1617 && !id.is_transient()
1618 {
1619 type_error!(
1620 false, "type warning: new non-transient global id {id}\n{}",
1622 relation.pretty()
1623 );
1624 }
1625 }
1626
1627 let got = self.typecheck(relation, &typecheck_ctx);
1628
1629 let humanizer = mz_repr::explain::DummyHumanizer;
1630
1631 match (got, expected) {
1632 (Ok(got), Some(expected)) => {
1633 let id = transform_ctx.global_id.unwrap();
1634
1635 let diffs = relation_subtype_difference(expected, &got);
1637 if !diffs.is_empty() {
1638 let severity = diffs
1640 .iter()
1641 .any(|diff| diff.clone().ignore_nullability().is_some());
1642
1643 let err = TypeError::MismatchColumns {
1644 source: relation,
1645 got,
1646 expected: expected.clone(),
1647 diffs,
1648 message: format!(
1649 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1650 ),
1651 };
1652
1653 type_error!(severity, "type error in known global id {id}:\n{err}");
1654 }
1655 }
1656 (Ok(got), None) => {
1657 if let Some(id) = transform_ctx.global_id {
1658 typecheck_ctx.insert(Id::Global(id), got);
1659 }
1660 }
1661 (Err(err), _) => {
1662 let (expected, binding) = match expected {
1663 Some(expected) => {
1664 let id = transform_ctx.global_id.unwrap();
1665 (
1666 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1667 format!("known global id {id}"),
1668 )
1669 }
1670 None => ("".to_string(), "transient query".to_string()),
1671 };
1672
1673 type_error!(
1674 true, "type error in {binding}:\n{err}\n{expected}{}",
1676 relation.pretty()
1677 );
1678 }
1679 }
1680
1681 Ok(())
1682 }
1683}
1684
1685pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1687where
1688 H: ExprHumanizer,
1689{
1690 let mut s = String::with_capacity(2 + 3 * cols.len());
1691
1692 s.push('(');
1693
1694 let mut it = cols.iter().peekable();
1695 while let Some(col) = it.next() {
1696 s.push_str(&humanizer.humanize_column_type(col, false));
1697
1698 if it.peek().is_some() {
1699 s.push_str(", ");
1700 }
1701 }
1702
1703 s.push(')');
1704
1705 s
1706}
1707
1708impl ReprRelationTypeDifference {
1709 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1713 where
1714 H: ExprHumanizer,
1715 {
1716 use ReprRelationTypeDifference::*;
1717 match self {
1718 Length { len_sub, len_sup } => {
1719 writeln!(
1720 f,
1721 " number of columns do not match ({len_sub} != {len_sup})"
1722 )
1723 }
1724 Column { col, diff } => {
1725 writeln!(f, " column {col} differs:")?;
1726 diff.humanize(4, h, f)
1727 }
1728 }
1729 }
1730}
1731
1732impl ReprColumnTypeDifference {
1733 pub fn humanize<H>(
1735 &self,
1736 indent: usize,
1737 h: &H,
1738 f: &mut std::fmt::Formatter<'_>,
1739 ) -> std::fmt::Result
1740 where
1741 H: ExprHumanizer,
1742 {
1743 use ReprColumnTypeDifference::*;
1744
1745 write!(f, "{:indent$}", "")?;
1747
1748 match self {
1749 NotSubtype { sub, sup } => {
1750 let sub = h.humanize_scalar_type(sub, false);
1751 let sup = h.humanize_scalar_type(sup, false);
1752
1753 writeln!(f, "{sub} is a not a subtype of {sup}")
1754 }
1755 Nullability { sub, sup } => {
1756 let sub = h.humanize_column_type(sub, false);
1757 let sup = h.humanize_column_type(sup, false);
1758
1759 writeln!(f, "{sub} is nullable but {sup} is not")
1760 }
1761 ElementType { ctor, element_type } => {
1762 writeln!(f, "{ctor} element types differ:")?;
1763
1764 element_type.humanize(indent + 2, h, f)
1765 }
1766 RecordMissingFields { missing } => {
1767 write!(f, "missing column fields:")?;
1768 for col in missing {
1769 write!(f, " {col}")?;
1770 }
1771 f.write_char('\n')
1772 }
1773 RecordFields { fields } => {
1774 writeln!(f, "{} record fields differ:", fields.len())?;
1775
1776 for (i, diff) in fields.iter().enumerate() {
1777 writeln!(f, "{:indent$} field {i}:", "")?;
1778 diff.humanize(indent + 4, h, f)?;
1779 }
1780 Ok(())
1781 }
1782 }
1783 }
1784}
1785
1786impl DatumTypeDifference {
1787 pub fn humanize<H>(
1789 &self,
1790 indent: usize,
1791 h: &H,
1792 f: &mut std::fmt::Formatter<'_>,
1793 ) -> std::fmt::Result
1794 where
1795 H: ExprHumanizer,
1796 {
1797 write!(f, "{:indent$}", "")?;
1799
1800 match self {
1801 DatumTypeDifference::Null { expected } => {
1802 let expected = h.humanize_scalar_type(expected, false);
1803 writeln!(
1804 f,
1805 "unexpected null, expected representation type {expected}"
1806 )?
1807 }
1808 DatumTypeDifference::Mismatch {
1809 got_debug,
1810 expected,
1811 } => {
1812 let expected = h.humanize_scalar_type(expected, false);
1813 writeln!(
1815 f,
1816 "got datum {got_debug}, expected representation type {expected}"
1817 )?;
1818 }
1819 DatumTypeDifference::MismatchDimensions {
1820 ctor,
1821 got,
1822 expected,
1823 } => {
1824 writeln!(
1825 f,
1826 "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1827 )?;
1828 }
1829 DatumTypeDifference::ElementType { ctor, element_type } => {
1830 writeln!(f, "{ctor} element types differ:")?;
1831 element_type.humanize(indent + 4, h, f)?;
1832 }
1833 }
1834
1835 Ok(())
1836 }
1837}
1838
1839#[allow(missing_debug_implementations)]
1841pub struct TypeErrorHumanizer<'a, 'b, H>
1842where
1843 H: ExprHumanizer,
1844{
1845 err: &'a TypeError<'a>,
1846 humanizer: &'b H,
1847}
1848
1849impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1850where
1851 H: ExprHumanizer,
1852{
1853 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1855 Self { err, humanizer }
1856 }
1857}
1858
1859impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1860where
1861 H: ExprHumanizer,
1862{
1863 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1864 self.err.humanize(self.humanizer, f)
1865 }
1866}
1867
1868impl<'a> std::fmt::Display for TypeError<'a> {
1869 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1870 TypeErrorHumanizer {
1871 err: self,
1872 humanizer: &DummyHumanizer,
1873 }
1874 .fmt(f)
1875 }
1876}
1877
1878impl<'a> TypeError<'a> {
1879 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1881 use TypeError::*;
1882 match self {
1883 Unbound { source, .. }
1884 | NoSuchColumn { source, .. }
1885 | MismatchColumn { source, .. }
1886 | MismatchColumns { source, .. }
1887 | BadConstantRowLen { source, .. }
1888 | BadConstantRow { source, .. }
1889 | BadProject { source, .. }
1890 | BadJoinEquivalence { source, .. }
1891 | BadTopKGroupKey { source, .. }
1892 | BadTopKOrdering { source, .. }
1893 | BadLetRecBindings { source }
1894 | Shadowing { source, .. }
1895 | DisallowedDummy { source, .. } => Some(source),
1896 Recursion { .. } => None,
1897 }
1898 }
1899
1900 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1901 where
1902 H: ExprHumanizer,
1903 {
1904 if let Some(source) = self.source() {
1905 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1906 }
1907
1908 use TypeError::*;
1909 match self {
1910 Unbound { source: _, id, typ } => {
1911 let typ = columns_pretty(&typ.column_types, humanizer);
1912 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1913 }
1914 NoSuchColumn {
1915 source: _,
1916 expr,
1917 col,
1918 } => writeln!(f, "{expr} references non-existent column {col}")?,
1919 MismatchColumn {
1920 source: _,
1921 got,
1922 expected,
1923 diffs,
1924 message,
1925 } => {
1926 let got = humanizer.humanize_column_type(got, false);
1927 let expected = humanizer.humanize_column_type(expected, false);
1928 writeln!(
1929 f,
1930 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1931 )?;
1932
1933 for diff in diffs {
1934 diff.humanize(2, humanizer, f)?;
1935 }
1936 }
1937 MismatchColumns {
1938 source: _,
1939 got,
1940 expected,
1941 diffs,
1942 message,
1943 } => {
1944 let got = columns_pretty(got, humanizer);
1945 let expected = columns_pretty(expected, humanizer);
1946
1947 writeln!(
1948 f,
1949 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1950 )?;
1951
1952 for diff in diffs {
1953 diff.humanize(humanizer, f)?;
1954 }
1955 }
1956 BadConstantRowLen {
1957 source: _,
1958 got,
1959 expected,
1960 } => {
1961 let expected = columns_pretty(expected, humanizer);
1962 writeln!(
1963 f,
1964 "bad constant row\n row has length {got}\nexpected row of type {expected}"
1965 )?
1966 }
1967 BadConstantRow {
1968 source: _,
1969 mismatches,
1970 expected,
1971 } => {
1972 let expected = columns_pretty(expected, humanizer);
1973
1974 let num_mismatches = mismatches.len();
1975 let plural = if num_mismatches == 1 { "" } else { "es" };
1976 writeln!(
1977 f,
1978 "bad constant row\n got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1979 )?;
1980
1981 if num_mismatches > 0 {
1982 writeln!(f, "")?;
1983 for (col, diff) in mismatches.iter() {
1984 writeln!(f, " column #{col}:")?;
1985 diff.humanize(8, humanizer, f)?;
1986 }
1987 }
1988 }
1989 BadProject {
1990 source: _,
1991 got,
1992 input_type,
1993 } => {
1994 let input_type = columns_pretty(input_type, humanizer);
1995
1996 writeln!(
1997 f,
1998 "projection of non-existant columns {got:?} from type {input_type}"
1999 )?
2000 }
2001 BadJoinEquivalence {
2002 source: _,
2003 got,
2004 message,
2005 } => {
2006 let got = columns_pretty(got, humanizer);
2007
2008 writeln!(f, "bad join equivalence {got}: {message}")?
2009 }
2010 BadTopKGroupKey {
2011 source: _,
2012 k,
2013 input_type,
2014 } => {
2015 let input_type = columns_pretty(input_type, humanizer);
2016
2017 writeln!(
2018 f,
2019 "TopK group key component references invalid column {k} in columns: {input_type}"
2020 )?
2021 }
2022 BadTopKOrdering {
2023 source: _,
2024 order,
2025 input_type,
2026 } => {
2027 let col = order.column;
2028 let num_cols = input_type.len();
2029 let are = if num_cols == 1 { "is" } else { "are" };
2030 let s = if num_cols == 1 { "" } else { "s" };
2031 let input_type = columns_pretty(input_type, humanizer);
2032
2033 let mode = HumanizedExplain::new(false);
2035 let order = mode.expr(order, None);
2036
2037 writeln!(
2038 f,
2039 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2040 )?
2041 }
2042 BadLetRecBindings { source: _ } => {
2043 writeln!(f, "LetRec ids and definitions don't line up")?
2044 }
2045 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2046 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2047 Recursion { error } => writeln!(f, "{error}")?,
2048 }
2049
2050 Ok(())
2051 }
2052}
2053
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{SqlColumnType, arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    // An `Int16` datum is accepted for an `Int16` column and rejected for an
    // `Int32` column.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    // Property: for a datum generated for an arbitrary column type,
    // `datum_difference_with_column_type` succeeds exactly when
    // `Datum::is_instance_of` holds.
    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            // Cases whose datum contains a dummy are rejected rather than checked.
            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    // Property: the same agreement holds when the column type and the datum
    // are generated independently of each other.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            // Here a dummy is treated as a generator bug, not a rejected case.
            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    // A list holding a single null must not be accepted for a record type with
    // one non-nullable `UInt32` field.
    // NOTE(review): the test name suggests a regression test for GitHub issue
    // 10039 — confirm against the issue tracker.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}