1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27 ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28};
29
30pub type SharedTypecheckingContext = Arc<Mutex<Context>>;
35
36pub fn empty_typechecking_context() -> SharedTypecheckingContext {
38 Arc::new(Mutex::new(BTreeMap::new()))
39}
40
/// An error found while typechecking a MIR expression.
///
/// Most variants carry `source`, the relation expression at which the problem was
/// detected.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referred to an identifier that is not bound in the context.
    Unbound {
        /// The offending `Get` expression.
        source: &'a MirRelationExpr,
        /// The unbound identifier.
        id: Id,
        /// The type annotation carried by the `Get`.
        typ: ReprRelationType,
    },
    /// A scalar expression referenced a column index past the end of its input.
    NoSuchColumn {
        /// The relation expression containing the scalar expression.
        source: &'a MirRelationExpr,
        /// The offending column reference.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column did not have the expected type.
    MismatchColumn {
        /// The expression at which the mismatch was detected.
        source: &'a MirRelationExpr,
        /// The type that was found.
        got: ReprColumnType,
        /// The type that was expected.
        expected: ReprColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprColumnTypeDifference>,
        /// A human-readable description of the check that failed.
        message: String,
    },
    /// A list of column types did not match the expected list.
    MismatchColumns {
        /// The expression at which the mismatch was detected.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ReprColumnType>,
        /// The column types that were expected.
        expected: Vec<ReprColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprRelationTypeDifference>,
        /// A human-readable description of the check that failed.
        message: String,
    },
    /// A constant row had a different number of datums than there are columns.
    BadConstantRowLen {
        /// The expression containing the constant row.
        source: &'a MirRelationExpr,
        /// The number of datums in the row.
        got: usize,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// Some datums of a constant row did not match their column types.
    BadConstantRow {
        /// The expression containing the constant row.
        source: &'a MirRelationExpr,
        /// `(column index, difference)` for every mismatching datum.
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A `Project` referenced a column index that is out of range for its input.
    BadProject {
        /// The offending `Project`.
        source: &'a MirRelationExpr,
        /// The projection's full list of output column indices.
        got: Vec<usize>,
        /// The input's column types.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class is unacceptable; in the visible code this is only
    /// constructed (and logged) when every member is nullable under strict checking.
    BadJoinEquivalence {
        /// The offending `Join`.
        source: &'a MirRelationExpr,
        /// The types of the equivalence class members.
        got: Vec<ReprColumnType>,
        /// A human-readable description of the problem.
        message: String,
    },
    /// A `TopK` group key column is out of range for its input.
    BadTopKGroupKey {
        /// The offending `TopK`.
        source: &'a MirRelationExpr,
        /// The out-of-range group key column.
        k: usize,
        /// The input's column types.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering refers to a column that is out of range for its input.
    BadTopKOrdering {
        /// The offending `TopK`.
        source: &'a MirRelationExpr,
        /// The ordering with the out-of-range column.
        order: ColumnOrder,
        /// The input's column types.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` has a different number of identifiers than values.
    BadLetRecBindings {
        /// The offending `LetRec`.
        source: &'a MirRelationExpr,
    },
    /// A binding reused an identifier that is already bound.
    Shadowing {
        /// The binding expression that shadows `id`.
        source: &'a MirRelationExpr,
        /// The shadowed identifier.
        id: Id,
    },
    /// The recursion guard's limit was exceeded.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A dummy value was found while dummies are disallowed.
    DisallowedDummy {
        /// The expression containing the dummy.
        source: &'a MirRelationExpr,
    },
}
170
171impl<'a> From<RecursionLimitError> for TypeError<'a> {
172 fn from(error: RecursionLimitError) -> Self {
173 TypeError::Recursion { error }
174 }
175}
176
/// Maps each identifier in scope to the column types of the relation it names.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
178
/// A reason why one relation type (a list of column types) is not a subtype of
/// another, as reported by `relation_subtype_difference`.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relations have different numbers of columns.
    Length {
        /// Number of columns in the candidate subtype.
        len_sub: usize,
        /// Number of columns in the candidate supertype.
        len_sup: usize,
    },
    /// A single column differs.
    Column {
        /// Zero-based index of the differing column.
        col: usize,
        /// How that column's types differ.
        diff: ReprColumnTypeDifference,
    },
}
199
/// A reason why one column type is not a subtype of, or cannot be unioned with,
/// another column type.
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar types are incompatible: their base types differ, or they are
    /// records of different arity.
    NotSubtype {
        /// The candidate subtype.
        sub: ReprScalarType,
        /// The candidate supertype.
        sup: ReprScalarType,
    },
    /// The nullabilities are incompatible. `column_subtype_difference` reports this
    /// when `sub` is nullable but `sup` is not.
    Nullability {
        /// The candidate subtype.
        sub: ReprColumnType,
        /// The candidate supertype.
        sup: ReprColumnType,
    },
    /// The element types of a container type differ.
    ElementType {
        /// The container's base type, rendered with `Debug`.
        ctor: String,
        /// How the element types differ.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// NOTE(review): not constructed by any function in this section; presumably
    /// lists record fields that are missing from the subtype — confirm.
    RecordMissingFields {
        /// The names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Per-field differences between record types.
    /// NOTE(review): not constructed by any function in this section.
    RecordFields {
        /// The differences, one per differing field.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
238
239impl ReprRelationTypeDifference {
240 pub fn ignore_nullability(self) -> Option<Self> {
244 use ReprRelationTypeDifference::*;
245
246 match self {
247 Length { .. } => Some(self),
248 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
249 }
250 }
251}
252
253impl ReprColumnTypeDifference {
254 pub fn ignore_nullability(self) -> Option<Self> {
258 use ReprColumnTypeDifference::*;
259
260 match self {
261 Nullability { .. } => None,
262 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
263 ElementType { ctor, element_type } => {
264 element_type
265 .ignore_nullability()
266 .map(|element_type| ElementType {
267 ctor,
268 element_type: Box::new(element_type),
269 })
270 }
271 RecordFields { fields } => {
272 let fields = fields
273 .into_iter()
274 .flat_map(|diff| diff.ignore_nullability())
275 .collect::<Vec<_>>();
276
277 if fields.is_empty() {
278 None
279 } else {
280 Some(RecordFields { fields })
281 }
282 }
283 }
284 }
285}
286
287pub fn relation_subtype_difference(
291 sub: &[ReprColumnType],
292 sup: &[ReprColumnType],
293) -> Vec<ReprRelationTypeDifference> {
294 let mut diffs = Vec::new();
295
296 if sub.len() != sup.len() {
297 diffs.push(ReprRelationTypeDifference::Length {
298 len_sub: sub.len(),
299 len_sup: sup.len(),
300 });
301
302 return diffs;
304 }
305
306 diffs.extend(
307 sub.iter()
308 .zip_eq(sup.iter())
309 .enumerate()
310 .flat_map(|(col, (sub_ty, sup_ty))| {
311 column_subtype_difference(sub_ty, sup_ty)
312 .into_iter()
313 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
314 }),
315 );
316
317 diffs
318}
319
320pub fn column_subtype_difference(
324 sub: &ReprColumnType,
325 sup: &ReprColumnType,
326) -> Vec<ReprColumnTypeDifference> {
327 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
328
329 if sub.nullable && !sup.nullable {
330 diffs.push(ReprColumnTypeDifference::Nullability {
331 sub: sub.clone(),
332 sup: sup.clone(),
333 });
334 }
335
336 diffs
337}
338
339pub fn scalar_subtype_difference(
343 sub: &ReprScalarType,
344 sup: &ReprScalarType,
345) -> Vec<ReprColumnTypeDifference> {
346 use ReprScalarType::*;
347
348 let mut diffs = Vec::new();
349
350 match (sub, sup) {
351 (
352 List {
353 element_type: sub_elt,
354 ..
355 },
356 List {
357 element_type: sup_elt,
358 ..
359 },
360 )
361 | (
362 Map {
363 value_type: sub_elt,
364 ..
365 },
366 Map {
367 value_type: sup_elt,
368 ..
369 },
370 )
371 | (
372 Range {
373 element_type: sub_elt,
374 ..
375 },
376 Range {
377 element_type: sup_elt,
378 ..
379 },
380 )
381 | (Array(sub_elt), Array(sup_elt)) => {
382 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
383 diffs.extend(
384 scalar_subtype_difference(sub_elt, sup_elt)
385 .into_iter()
386 .map(|diff| ReprColumnTypeDifference::ElementType {
387 ctor: ctor.clone(),
388 element_type: Box::new(diff),
389 }),
390 );
391 }
392 (
393 Record {
394 fields: sub_fields, ..
395 },
396 Record {
397 fields: sup_fields, ..
398 },
399 ) => {
400 if sub_fields.len() != sup_fields.len() {
401 diffs.push(ReprColumnTypeDifference::NotSubtype {
402 sub: sub.clone(),
403 sup: sup.clone(),
404 });
405 return diffs;
406 }
407
408 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
409 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
410 }
411 }
412 (_, _) => {
413 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
414 diffs.push(ReprColumnTypeDifference::NotSubtype {
415 sub: sub.clone(),
416 sup: sup.clone(),
417 })
418 }
419 }
420 };
421
422 diffs
423}
424
425pub fn scalar_union(
429 typ: &mut ReprScalarType,
430 other: &ReprScalarType,
431) -> Vec<ReprColumnTypeDifference> {
432 use ReprScalarType::*;
433
434 let mut diffs = Vec::new();
435
436 let ctor = ReprScalarBaseType::from(&*typ);
438 match (typ, other) {
439 (
440 List {
441 element_type: typ_elt,
442 },
443 List {
444 element_type: other_elt,
445 },
446 )
447 | (
448 Map {
449 value_type: typ_elt,
450 },
451 Map {
452 value_type: other_elt,
453 },
454 )
455 | (
456 Range {
457 element_type: typ_elt,
458 },
459 Range {
460 element_type: other_elt,
461 },
462 )
463 | (Array(typ_elt), Array(other_elt)) => {
464 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
465 diffs.extend(
466 res.into_iter()
467 .map(|diff| ReprColumnTypeDifference::ElementType {
468 ctor: format!("{ctor:?}"),
469 element_type: Box::new(diff),
470 }),
471 );
472 }
473 (
474 Record { fields: typ_fields },
475 Record {
476 fields: other_fields,
477 },
478 ) => {
479 if typ_fields.len() != other_fields.len() {
480 diffs.push(ReprColumnTypeDifference::NotSubtype {
481 sub: ReprScalarType::Record {
482 fields: typ_fields.clone(),
483 },
484 sup: other.clone(),
485 });
486 return diffs;
487 }
488
489 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
490 diffs.extend(column_union(typ_ty, other_ty));
491 }
492 }
493 (typ, _) => {
494 if ctor != ReprScalarBaseType::from(other) {
495 diffs.push(ReprColumnTypeDifference::NotSubtype {
496 sub: typ.clone(),
497 sup: other.clone(),
498 })
499 }
500 }
501 };
502
503 diffs
504}
505
506pub fn column_union(
510 typ: &mut ReprColumnType,
511 other: &ReprColumnType,
512) -> Vec<ReprColumnTypeDifference> {
513 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
514
515 if diffs.is_empty() {
516 typ.nullable |= other.nullable;
517 }
518
519 diffs
520}
521
522pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
527 if sub.len() != sup.len() {
528 return false;
529 }
530
531 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
532 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
533 })
534}
535
/// A reason why a datum does not inhabit an expected type.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// A `NULL` appeared where a non-null value was required.
    Null {
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// The datum is not a value of the expected scalar type.
    Mismatch {
        /// The `Debug` rendering of the offending datum; only this string is kept,
        /// not the datum itself.
        got_debug: String,
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// A container had the wrong number of dimensions or elements.
    MismatchDimensions {
        /// The kind of container, e.g. `"int2vector"` or `"record"`.
        ctor: String,
        /// The count that was found.
        got: usize,
        /// The count that was expected.
        expected: usize,
    },
    /// An element of a container did not match the container's element type.
    ElementType {
        /// The kind of container, e.g. `"array"`, `"list"`, `"map"` or `"range"`.
        ctor: String,
        /// How the element differs.
        element_type: Box<DatumTypeDifference>,
    },
}
569
/// Checks that `datum` inhabits `column_type`.
///
/// A `NULL` datum is accepted exactly when `column_type` is nullable. Any other datum
/// is checked against `column_type.scalar_type`. `Datum::Dummy` is accepted against
/// every scalar type.
///
/// # Errors
///
/// Returns a [`DatumTypeDifference`] describing the first mismatch found.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    /// Checks that `datum` is a value of `scalar_type`. A `NULL` that reaches this
    /// function is reported as `DatumTypeDifference::Null`.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        /// Builds the generic "wrong kind of datum" error for `got`.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        if let ReprScalarType::Jsonb = scalar_type {
            // JSON values: JSON null, booleans, numbers and strings, plus lists and maps
            // whose contents are themselves checked as JSON.
            match datum {
                // Dummies are accepted here; `Typecheck::disallow_dummy` is checked
                // separately by `Typecheck::typecheck`.
                Datum::Dummy => Ok(()),
                // A SQL `NULL` is not a JSON value (JSON null is `Datum::JsonNull`).
                Datum::Null => Err(DatumTypeDifference::Null {
                    expected: ReprScalarType::Jsonb,
                }),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                Datum::List(list) => {
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    // Only map values are checked; keys are not inspected here.
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            /// Wraps an element-level difference with the name of its container.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            // Each datum kind matches exactly one scalar type (or one family of
            // containers); every other pairing is a `Mismatch`.
            match (datum, scalar_type) {
                // Dummies are accepted; see the JSON branch above.
                (Datum::Dummy, _) => Ok(()),
                (Datum::Null, _) => Err(DatumTypeDifference::Null {
                    expected: scalar_type.clone(),
                }),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                // The fields of the timestamp scalar types are not inspected.
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    // Array elements may be NULL; only non-null elements are checked.
                    // Dimensions are not checked.
                    for e in array.elements().iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    // An int2vector is an array with int2vector-shaped dimensions
                    // (reported against an expected dimension count of 1) whose elements
                    // are all `Int16`; NULL elements are rejected.
                    if !array.has_int2vector_dims() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    // List elements may be NULL; only non-null elements are checked.
                    for e in list.iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    // Records are encoded as lists: arity must match, and each entry is
                    // checked against its field's column type, including nullability.
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        if let Datum::Null = e {
                            if t.nullable {
                                continue;
                            } else {
                                return Err(DatumTypeDifference::Null {
                                    expected: t.scalar_type.clone(),
                                });
                            }
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    // Map values may be NULL; only non-null values are checked.
                    for (_, v) in map.iter() {
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // JSON null only inhabits `Jsonb`, which is handled above.
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    // A range without `inner` is accepted as is; otherwise each present
                    // bound is checked against the element type.
                    match inner {
                        None => Ok(()),
                        Some(inner) => {
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // A top-level NULL is fine for a nullable column; all other cases defer to the
    // scalar check, which rejects NULL.
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
786
787fn row_difference_with_column_types<'a>(
788 source: &'a MirRelationExpr,
789 datums: &Vec<Datum<'_>>,
790 column_types: &[ReprColumnType],
791) -> Result<(), TypeError<'a>> {
792 if datums.len() != column_types.len() {
794 return Err(TypeError::BadConstantRowLen {
795 source,
796 got: datums.len(),
797 expected: column_types.to_vec(),
798 });
799 }
800
801 let mut mismatches = Vec::new();
803 for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
804 if let Err(e) = datum_difference_with_column_type(d, ty) {
805 mismatches.push((i, e));
806 }
807 }
808 if !mismatches.is_empty() {
809 return Err(TypeError::BadConstantRow {
810 source,
811 mismatches,
812 expected: column_types.to_vec(),
813 });
814 }
815
816 Ok(())
817}
/// A configurable typechecker for MIR relation expressions.
///
/// Construct with `Typecheck::new`, then enable optional checks with the builder
/// methods `disallow_new_globals`, `strict_join_equivalences` and `disallow_dummy`.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared typechecking context.
    /// NOTE(review): not read by the methods visible here; `typecheck` receives its
    /// context as an argument.
    ctx: SharedTypecheckingContext,
    /// Set by `disallow_new_globals`.
    /// NOTE(review): not read by the code visible here.
    disallow_new_globals: bool,
    /// When set, join equivalence classes whose members differ in nullability, or whose
    /// members are all nullable, are logged at debug level (not returned as errors).
    strict_join_equivalences: bool,
    /// When set, `Dummy` datums in constants and scalar expressions containing dummies
    /// are rejected with `TypeError::DisallowedDummy`.
    disallow_dummy: bool,
    /// Bounds the recursion depth of the checking methods (via `checked_recur`).
    recursion_guard: RecursionGuard,
}
832
/// Exposes `recursion_guard` so the recursive checking methods can use
/// `checked_recur`.
impl CheckedRecursion for Typecheck {
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
838
839impl Typecheck {
840 pub fn new(ctx: SharedTypecheckingContext) -> Self {
842 Self {
843 ctx,
844 disallow_new_globals: false,
845 strict_join_equivalences: false,
846 disallow_dummy: false,
847 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
848 }
849 }
850
851 pub fn disallow_new_globals(mut self) -> Self {
855 self.disallow_new_globals = true;
856 self
857 }
858
859 pub fn strict_join_equivalences(mut self) -> Self {
863 self.strict_join_equivalences = true;
864
865 self
866 }
867
868 pub fn disallow_dummy(mut self) -> Self {
870 self.disallow_dummy = true;
871 self
872 }
873
874 pub fn typecheck<'a>(
885 &self,
886 expr: &'a MirRelationExpr,
887 ctx: &Context,
888 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
889 use MirRelationExpr::*;
890
891 self.checked_recur(|tc| match expr {
892 Constant { typ, rows } => {
893 if let Ok(rows) = rows {
894 for (row, _id) in rows {
895 let datums = row.unpack();
896
897 let col_types = typ
898 .column_types
899 .iter()
900 .cloned()
901 .collect_vec();
902 row_difference_with_column_types(
903 expr, &datums, &col_types,
904 )?;
905
906 if self.disallow_dummy
907 && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
908 {
909 return Err(TypeError::DisallowedDummy {
910 source: expr,
911 });
912 }
913 }
914 }
915
916 Ok(typ.column_types.iter().cloned().collect_vec())
917 }
918 Get { typ, id, .. } => {
919 if let Id::Global(_global_id) = id {
920 if !ctx.contains_key(id) {
921 return Ok(typ.column_types.iter().cloned().collect_vec());
923 }
924 }
925
926 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
927 source: expr,
928 id: id.clone(),
929 typ: typ.clone(),
930 })?;
931
932 let column_types = typ.column_types.iter().cloned().collect_vec();
933
934 let diffs = relation_subtype_difference(&column_types, ctx_typ)
936 .into_iter()
937 .flat_map(|diff| diff.ignore_nullability())
938 .collect::<Vec<_>>();
939
940 if !diffs.is_empty() {
941 return Err(TypeError::MismatchColumns {
942 source: expr,
943 got: column_types,
944 expected: ctx_typ.clone(),
945 diffs,
946 message: "annotation did not match context type".to_string(),
947 });
948 }
949
950 Ok(column_types)
951 }
952 Project { input, outputs } => {
953 let t_in = tc.typecheck(input, ctx)?;
954
955 for x in outputs {
956 if *x >= t_in.len() {
957 return Err(TypeError::BadProject {
958 source: expr,
959 got: outputs.clone(),
960 input_type: t_in,
961 });
962 }
963 }
964
965 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
966 }
967 Map { input, scalars } => {
968 let mut t_in = tc.typecheck(input, ctx)?;
969
970 for scalar_expr in scalars.iter() {
971 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
972
973 if self.disallow_dummy && scalar_expr.contains_dummy() {
974 return Err(TypeError::DisallowedDummy {
975 source: expr,
976 });
977 }
978 }
979
980 Ok(t_in)
981 }
982 FlatMap { input, func, exprs } => {
983 let mut t_in = tc.typecheck(input, ctx)?;
984
985 for scalar_expr in exprs {
986 let _t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
988
989 if self.disallow_dummy && scalar_expr.contains_dummy() {
990 return Err(TypeError::DisallowedDummy {
991 source: expr,
992 });
993 }
994 }
995
996 let t_out: Vec<ReprColumnType> = func
997 .output_type().column_types;
998
999 t_in.extend(t_out);
1001 Ok(t_in)
1002 }
1003 Filter { input, predicates } => {
1004 let mut t_in = tc.typecheck(input, ctx)?;
1005
1006 for column in non_nullable_columns(predicates) {
1009 t_in[column].nullable = false;
1010 }
1011
1012 for scalar_expr in predicates {
1013 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1014
1015 if t.scalar_type != ReprScalarType::Bool {
1019 let sub = t.scalar_type.clone();
1020
1021 return Err(TypeError::MismatchColumn {
1022 source: expr,
1023 got: t,
1024 expected: ReprColumnType {
1025 scalar_type: ReprScalarType::Bool,
1026 nullable: true,
1027 },
1028 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1029 sub,
1030 sup: ReprScalarType::Bool,
1031 }],
1032 message: "expected boolean condition".to_string(),
1033 });
1034 }
1035
1036 if self.disallow_dummy && scalar_expr.contains_dummy() {
1037 return Err(TypeError::DisallowedDummy {
1038 source: expr,
1039 });
1040 }
1041 }
1042
1043 Ok(t_in)
1044 }
1045 Join {
1046 inputs,
1047 equivalences,
1048 implementation,
1049 } => {
1050 let mut t_in_global = Vec::new();
1051 let mut t_in_local = vec![Vec::new(); inputs.len()];
1052
1053 for (i, input) in inputs.iter().enumerate() {
1054 let input_t = tc.typecheck(input, ctx)?;
1055 t_in_global.extend(input_t.clone());
1056 t_in_local[i] = input_t;
1057 }
1058
1059 for eq_class in equivalences {
1060 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1061
1062 let mut all_nullable = true;
1063
1064 for scalar_expr in eq_class {
1065 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1067
1068 if !t_expr.nullable {
1069 all_nullable = false;
1070 }
1071
1072 if let Some(t_first) = t_exprs.get(0) {
1073 let diffs = scalar_subtype_difference(
1074 &t_expr.scalar_type,
1075 &t_first.scalar_type,
1076 );
1077 if !diffs.is_empty() {
1078 return Err(TypeError::MismatchColumn {
1079 source: expr,
1080 got: t_expr,
1081 expected: t_first.clone(),
1082 diffs,
1083 message: "equivalence class members \
1084 have different scalar types"
1085 .to_string(),
1086 });
1087 }
1088
1089 if self.strict_join_equivalences {
1093 if t_expr.nullable != t_first.nullable {
1094 let sub = t_expr.clone();
1095 let sup = t_first.clone();
1096
1097 let err = TypeError::MismatchColumn {
1098 source: expr,
1099 got: t_expr.clone(),
1100 expected: t_first.clone(),
1101 diffs: vec![
1102 ReprColumnTypeDifference::Nullability { sub, sup },
1103 ],
1104 message: "equivalence class members have \
1105 different nullability (and join \
1106 equivalence checking is strict)"
1107 .to_string(),
1108 };
1109
1110 ::tracing::debug!("{err}");
1112 }
1113 }
1114 }
1115
1116 if self.disallow_dummy && scalar_expr.contains_dummy() {
1117 return Err(TypeError::DisallowedDummy {
1118 source: expr,
1119 });
1120 }
1121
1122 t_exprs.push(t_expr);
1123 }
1124
1125 if self.strict_join_equivalences && all_nullable {
1126 let err = TypeError::BadJoinEquivalence {
1127 source: expr,
1128 got: t_exprs,
1129 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1130 };
1131
1132 ::tracing::debug!("{err}");
1134 }
1135 }
1136
1137 match implementation {
1139 JoinImplementation::Differential((start_idx, first_key, _), others) => {
1140 if let Some(key) = first_key {
1141 for k in key {
1142 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1143 }
1144 }
1145
1146 for (idx, key, _) in others {
1147 for k in key {
1148 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1149 }
1150 }
1151 }
1152 JoinImplementation::DeltaQuery(plans) => {
1153 for plan in plans {
1154 for (idx, key, _) in plan {
1155 for k in key {
1156 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1157 }
1158 }
1159 }
1160 }
1161 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1162 let typ: Vec<ReprColumnType> = key
1163 .iter()
1164 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1165 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1166
1167 for row in consts {
1168 let datums = row.unpack();
1169
1170 row_difference_with_column_types(expr, &datums, &typ)?;
1171 }
1172 }
1173 JoinImplementation::Unimplemented => (),
1174 }
1175
1176 Ok(t_in_global)
1177 }
1178 Reduce {
1179 input,
1180 group_key,
1181 aggregates,
1182 monotonic: _,
1183 expected_group_size: _,
1184 } => {
1185 let t_in = tc.typecheck(input, ctx)?;
1186
1187 let mut t_out = group_key
1188 .iter()
1189 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1190 .collect::<Result<Vec<_>, _>>()?;
1191
1192 if self.disallow_dummy
1193 && group_key
1194 .iter()
1195 .any(|scalar_expr| scalar_expr.contains_dummy())
1196 {
1197 return Err(TypeError::DisallowedDummy {
1198 source: expr,
1199 });
1200 }
1201
1202 for agg in aggregates {
1203 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1204 }
1205
1206 Ok(t_out)
1207 }
1208 TopK {
1209 input,
1210 group_key,
1211 order_key,
1212 limit: _,
1213 offset: _,
1214 monotonic: _,
1215 expected_group_size: _,
1216 } => {
1217 let t_in = tc.typecheck(input, ctx)?;
1218
1219 for &k in group_key {
1220 if k >= t_in.len() {
1221 return Err(TypeError::BadTopKGroupKey {
1222 source: expr,
1223 k,
1224 input_type: t_in,
1225 });
1226 }
1227 }
1228
1229 for order in order_key {
1230 if order.column >= t_in.len() {
1231 return Err(TypeError::BadTopKOrdering {
1232 source: expr,
1233 order: order.clone(),
1234 input_type: t_in,
1235 });
1236 }
1237 }
1238
1239 Ok(t_in)
1240 }
1241 Negate { input } => tc.typecheck(input, ctx),
1242 Threshold { input } => tc.typecheck(input, ctx),
1243 Union { base, inputs } => {
1244 let mut t_base = tc.typecheck(base, ctx)?;
1245
1246 for input in inputs {
1247 let t_input = tc.typecheck(input, ctx)?;
1248
1249 let len_sub = t_base.len();
1250 let len_sup = t_input.len();
1251 if len_sub != len_sup {
1252 return Err(TypeError::MismatchColumns {
1253 source: expr,
1254 got: t_base.clone(),
1255 expected: t_input,
1256 diffs: vec![ReprRelationTypeDifference::Length {
1257 len_sub,
1258 len_sup,
1259 }],
1260 message: "Union branches have different numbers of columns".to_string(),
1261 });
1262 }
1263
1264 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1265 let diffs = column_union(base_col, &input_col);
1266 if !diffs.is_empty() {
1267 return Err(TypeError::MismatchColumn {
1268 source: expr,
1269 got: input_col,
1270 expected: base_col.clone(),
1271 diffs,
1272 message:
1273 "couldn't compute union of column types in Union"
1274 .to_string(),
1275 });
1276 }
1277
1278 }
1279 }
1280
1281 Ok(t_base)
1282 }
1283 Let { id, value, body } => {
1284 let t_value = tc.typecheck(value, ctx)?;
1285
1286 let binding = Id::Local(*id);
1287 if ctx.contains_key(&binding) {
1288 return Err(TypeError::Shadowing {
1289 source: expr,
1290 id: binding,
1291 });
1292 }
1293
1294 let mut body_ctx = ctx.clone();
1295 body_ctx.insert(Id::Local(*id), t_value);
1296
1297 tc.typecheck(body, &body_ctx)
1298 }
1299 LetRec { ids, values, body, limits: _ } => {
1300 if ids.len() != values.len() {
1301 return Err(TypeError::BadLetRecBindings { source: expr });
1302 }
1303
1304 let mut ctx = ctx.clone();
1307 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1309 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1310 }
1311
1312 for (id, value) in ids.iter().zip_eq(values.iter()) {
1313 let typ = tc.typecheck(value, &ctx)?;
1314
1315 let id = Id::Local(id.clone());
1316 if let Some(ctx_typ) = ctx.get_mut(&id) {
1317 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1318 let diffs = column_union(base_col, &input_col);
1320 if !diffs.is_empty() {
1321 return Err(TypeError::MismatchColumn {
1322 source: expr,
1323 got: input_col,
1324 expected: base_col.clone(),
1325 diffs,
1326 message:
1327 "couldn't compute union of column types in LetRec"
1328 .to_string(),
1329 })
1330 }
1331 }
1332 } else {
1333 ctx.insert(id, typ);
1335 }
1336 }
1337
1338 tc.typecheck(body, &ctx)
1339 }
1340 ArrangeBy { input, keys } => {
1341 let t_in = tc.typecheck(input, ctx)?;
1342
1343 for key in keys {
1344 for k in key {
1345 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1346 }
1347 }
1348
1349 Ok(t_in)
1350 }
1351 })
1352 }
1353
1354 fn collect_recursive_variable_types<'a>(
1358 &self,
1359 expr: &'a MirRelationExpr,
1360 ids: &[LocalId],
1361 ctx: &mut Context,
1362 ) -> Result<(), TypeError<'a>> {
1363 use MirRelationExpr::*;
1364
1365 self.checked_recur(|tc| {
1366 match expr {
1367 Get {
1368 id: Id::Local(id),
1369 typ,
1370 ..
1371 } => {
1372 if !ids.contains(id) {
1373 return Ok(());
1374 }
1375
1376 let id = Id::Local(id.clone());
1377 if let Some(ctx_typ) = ctx.get_mut(&id) {
1378 let typ = typ.column_types.iter().cloned().collect_vec();
1379
1380 if ctx_typ.len() != typ.len() {
1381 let diffs = relation_subtype_difference(&typ, ctx_typ);
1382
1383 return Err(TypeError::MismatchColumns {
1384 source: expr,
1385 got: typ,
1386 expected: ctx_typ.clone(),
1387 diffs,
1388 message: "environment and type annotation did not match"
1389 .to_string(),
1390 });
1391 }
1392
1393 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1394 let diffs = column_union(base_col, &input_col);
1395 if !diffs.is_empty() {
1396 return Err(TypeError::MismatchColumn {
1397 source: expr,
1398 got: input_col,
1399 expected: base_col.clone(),
1400 diffs,
1401 message:
1402 "couldn't compute union of column types in Get and context"
1403 .to_string(),
1404 });
1405 }
1406 }
1407 } else {
1408 ctx.insert(id, typ.column_types.iter().cloned().collect_vec());
1409 }
1410 }
1411 Get {
1412 id: Id::Global(..), ..
1413 }
1414 | Constant { .. } => (),
1415 Let { id, value, body } => {
1416 tc.collect_recursive_variable_types(value, ids, ctx)?;
1417
1418 if ids.contains(id) {
1420 return Err(TypeError::Shadowing {
1421 source: expr,
1422 id: Id::Local(*id),
1423 });
1424 }
1425
1426 tc.collect_recursive_variable_types(body, ids, ctx)?;
1427 }
1428 LetRec {
1429 ids: inner_ids,
1430 values,
1431 body,
1432 limits: _,
1433 } => {
1434 for inner_id in inner_ids {
1435 if ids.contains(inner_id) {
1436 return Err(TypeError::Shadowing {
1437 source: expr,
1438 id: Id::Local(*inner_id),
1439 });
1440 }
1441 }
1442
1443 for value in values {
1444 tc.collect_recursive_variable_types(value, ids, ctx)?;
1445 }
1446
1447 tc.collect_recursive_variable_types(body, ids, ctx)?;
1448 }
1449 Project { input, .. }
1450 | Map { input, .. }
1451 | FlatMap { input, .. }
1452 | Filter { input, .. }
1453 | Reduce { input, .. }
1454 | TopK { input, .. }
1455 | Negate { input }
1456 | Threshold { input }
1457 | ArrangeBy { input, .. } => {
1458 tc.collect_recursive_variable_types(input, ids, ctx)?;
1459 }
1460 Join { inputs, .. } => {
1461 for input in inputs {
1462 tc.collect_recursive_variable_types(input, ids, ctx)?;
1463 }
1464 }
1465 Union { base, inputs } => {
1466 tc.collect_recursive_variable_types(base, ids, ctx)?;
1467
1468 for input in inputs {
1469 tc.collect_recursive_variable_types(input, ids, ctx)?;
1470 }
1471 }
1472 }
1473
1474 Ok(())
1475 })
1476 }
1477
1478 fn typecheck_scalar<'a>(
1479 &self,
1480 expr: &'a MirScalarExpr,
1481 source: &'a MirRelationExpr,
1482 column_types: &[ReprColumnType],
1483 ) -> Result<ReprColumnType, TypeError<'a>> {
1484 use MirScalarExpr::*;
1485
1486 self.checked_recur(|tc| match expr {
1487 Column(i, _) => match column_types.get(*i) {
1488 Some(ty) => Ok(ty.clone()),
1489 None => Err(TypeError::NoSuchColumn {
1490 source,
1491 expr,
1492 col: *i,
1493 }),
1494 },
1495 Literal(row, typ) => {
1496 let typ = typ.clone();
1497 if let Ok(row) = row {
1498 let datums = row.unpack();
1499
1500 row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1501 }
1502
1503 Ok(typ)
1504 }
1505 CallUnmaterializable(func) => Ok(func.output_type()),
1506 CallUnary { expr, func } => {
1507 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1508 let typ_out = func.output_type(typ_in);
1509 Ok(typ_out)
1510 }
1511 CallBinary { expr1, expr2, func } => {
1512 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1513 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1514 let typ_out = func.output_type(&[typ_in1, typ_in2]);
1515 Ok(typ_out)
1516 }
1517 CallVariadic { exprs, func } => Ok(func.output_type(
1518 exprs
1519 .iter()
1520 .map(|e| tc.typecheck_scalar(e, source, column_types))
1521 .collect::<Result<Vec<_>, TypeError>>()?,
1522 )),
1523 If { cond, then, els } => {
1524 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1525
1526 if cond_type.scalar_type != ReprScalarType::Bool {
1530 let sub = cond_type.scalar_type.clone();
1531
1532 return Err(TypeError::MismatchColumn {
1533 source,
1534 got: cond_type,
1535 expected: ReprColumnType {
1536 scalar_type: ReprScalarType::Bool,
1537 nullable: true,
1538 },
1539 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1540 sub,
1541 sup: ReprScalarType::Bool,
1542 }],
1543 message: "expected boolean condition".to_string(),
1544 });
1545 }
1546
1547 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1548 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1549
1550 let diffs = column_union(&mut then_type, &else_type);
1551 if !diffs.is_empty() {
1552 return Err(TypeError::MismatchColumn {
1553 source,
1554 got: then_type,
1555 expected: else_type,
1556 diffs,
1557 message: "couldn't compute union of column types for If".to_string(),
1558 });
1559 }
1560
1561 Ok(then_type)
1562 }
1563 })
1564 }
1565
1566 pub fn typecheck_aggregate<'a>(
1568 &self,
1569 expr: &'a AggregateExpr,
1570 source: &'a MirRelationExpr,
1571 column_types: &[ReprColumnType],
1572 ) -> Result<ReprColumnType, TypeError<'a>> {
1573 self.checked_recur(|tc| {
1574 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1575
1576 Ok(expr.func.output_type(t_in))
1579 })
1580 }
1581}
1582
/// Reports a typechecking problem.
///
/// The first argument is a `bool` severity. When it is `true`, the remaining
/// format arguments are passed to `soft_panic_or_log!`. When it is `false`,
/// they are only emitted with `tracing::debug!`.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        let is_severe: bool = $severity;
        if !is_severe {
            ::tracing::debug!($($arg)+);
        } else {
            soft_panic_or_log!($($arg)+);
        }
    }};
}
1595
1596impl crate::Transform for Typecheck {
1597 fn name(&self) -> &'static str {
1598 "Typecheck"
1599 }
1600
1601 fn actually_perform_transform(
1602 &self,
1603 relation: &mut MirRelationExpr,
1604 transform_ctx: &mut crate::TransformCtx,
1605 ) -> Result<(), crate::TransformError> {
1606 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1607
1608 let expected = transform_ctx
1609 .global_id
1610 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1611
1612 if let Some(id) = transform_ctx.global_id {
1613 if self.disallow_new_globals
1614 && expected.is_none()
1615 && transform_ctx.global_id.is_some()
1616 && !id.is_transient()
1617 {
1618 type_error!(
1619 false, "type warning: new non-transient global id {id}\n{}",
1621 relation.pretty()
1622 );
1623 }
1624 }
1625
1626 let got = self.typecheck(relation, &typecheck_ctx);
1627
1628 let humanizer = mz_repr::explain::DummyHumanizer;
1629
1630 match (got, expected) {
1631 (Ok(got), Some(expected)) => {
1632 let id = transform_ctx.global_id.unwrap();
1633
1634 let diffs = relation_subtype_difference(expected, &got);
1636 if !diffs.is_empty() {
1637 let severity = diffs
1639 .iter()
1640 .any(|diff| diff.clone().ignore_nullability().is_some());
1641
1642 let err = TypeError::MismatchColumns {
1643 source: relation,
1644 got,
1645 expected: expected.clone(),
1646 diffs,
1647 message: format!(
1648 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1649 ),
1650 };
1651
1652 type_error!(severity, "type error in known global id {id}:\n{err}");
1653 }
1654 }
1655 (Ok(got), None) => {
1656 if let Some(id) = transform_ctx.global_id {
1657 typecheck_ctx.insert(Id::Global(id), got);
1658 }
1659 }
1660 (Err(err), _) => {
1661 let (expected, binding) = match expected {
1662 Some(expected) => {
1663 let id = transform_ctx.global_id.unwrap();
1664 (
1665 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1666 format!("known global id {id}"),
1667 )
1668 }
1669 None => ("".to_string(), "transient query".to_string()),
1670 };
1671
1672 type_error!(
1673 true, "type error in {binding}:\n{err}\n{expected}{}",
1675 relation.pretty()
1676 );
1677 }
1678 }
1679
1680 Ok(())
1681 }
1682}
1683
1684pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1686where
1687 H: ExprHumanizer,
1688{
1689 let mut s = String::with_capacity(2 + 3 * cols.len());
1690
1691 s.push('(');
1692
1693 let mut it = cols.iter().peekable();
1694 while let Some(col) = it.next() {
1695 s.push_str(&humanizer.humanize_column_type(col));
1696
1697 if it.peek().is_some() {
1698 s.push_str(", ");
1699 }
1700 }
1701
1702 s.push(')');
1703
1704 s
1705}
1706
1707impl ReprRelationTypeDifference {
1708 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1712 where
1713 H: ExprHumanizer,
1714 {
1715 use ReprRelationTypeDifference::*;
1716 match self {
1717 Length { len_sub, len_sup } => {
1718 writeln!(
1719 f,
1720 " number of columns do not match ({len_sub} != {len_sup})"
1721 )
1722 }
1723 Column { col, diff } => {
1724 writeln!(f, " column {col} differs:")?;
1725 diff.humanize(4, h, f)
1726 }
1727 }
1728 }
1729}
1730
1731impl ReprColumnTypeDifference {
1732 pub fn humanize<H>(
1734 &self,
1735 indent: usize,
1736 h: &H,
1737 f: &mut std::fmt::Formatter<'_>,
1738 ) -> std::fmt::Result
1739 where
1740 H: ExprHumanizer,
1741 {
1742 use ReprColumnTypeDifference::*;
1743
1744 write!(f, "{:indent$}", "")?;
1746
1747 match self {
1748 NotSubtype { sub, sup } => {
1749 let sub = h.humanize_scalar_type(sub);
1750 let sup = h.humanize_scalar_type(sup);
1751
1752 writeln!(f, "{sub} is a not a subtype of {sup}")
1753 }
1754 Nullability { sub, sup } => {
1755 let sub = h.humanize_column_type(sub);
1756 let sup = h.humanize_column_type(sup);
1757
1758 writeln!(f, "{sub} is nullable but {sup} is not")
1759 }
1760 ElementType { ctor, element_type } => {
1761 writeln!(f, "{ctor} element types differ:")?;
1762
1763 element_type.humanize(indent + 2, h, f)
1764 }
1765 RecordMissingFields { missing } => {
1766 write!(f, "missing column fields:")?;
1767 for col in missing {
1768 write!(f, " {col}")?;
1769 }
1770 f.write_char('\n')
1771 }
1772 RecordFields { fields } => {
1773 writeln!(f, "{} record fields differ:", fields.len())?;
1774
1775 for (i, diff) in fields.iter().enumerate() {
1776 writeln!(f, "{:indent$} field {i}:", "")?;
1777 diff.humanize(indent + 4, h, f)?;
1778 }
1779 Ok(())
1780 }
1781 }
1782 }
1783}
1784
1785impl DatumTypeDifference {
1786 pub fn humanize<H>(
1788 &self,
1789 indent: usize,
1790 h: &H,
1791 f: &mut std::fmt::Formatter<'_>,
1792 ) -> std::fmt::Result
1793 where
1794 H: ExprHumanizer,
1795 {
1796 write!(f, "{:indent$}", "")?;
1798
1799 match self {
1800 DatumTypeDifference::Null { expected } => {
1801 let expected = h.humanize_scalar_type(expected);
1802 writeln!(
1803 f,
1804 "unexpected null, expected representation type {expected}"
1805 )?
1806 }
1807 DatumTypeDifference::Mismatch {
1808 got_debug,
1809 expected,
1810 } => {
1811 let expected = h.humanize_scalar_type(expected);
1812 writeln!(
1814 f,
1815 "got datum {got_debug}, expected representation type {expected}"
1816 )?;
1817 }
1818 DatumTypeDifference::MismatchDimensions {
1819 ctor,
1820 got,
1821 expected,
1822 } => {
1823 writeln!(
1824 f,
1825 "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1826 )?;
1827 }
1828 DatumTypeDifference::ElementType { ctor, element_type } => {
1829 writeln!(f, "{ctor} element types differ:")?;
1830 element_type.humanize(indent + 4, h, f)?;
1831 }
1832 }
1833
1834 Ok(())
1835 }
1836}
1837
/// Pairs a [`TypeError`] with an [`ExprHumanizer`], so that displaying the
/// pair renders the error's types with that humanizer.
// NOTE(review): presumably `Debug` is not implemented because `H` is not
// required to implement `Debug`; confirm.
#[allow(missing_debug_implementations)]
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to render.
    err: &'a TypeError<'a>,
    /// Used to render the types mentioned in `err`.
    humanizer: &'b H,
}
1847
1848impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1849where
1850 H: ExprHumanizer,
1851{
1852 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1854 Self { err, humanizer }
1855 }
1856}
1857
1858impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1859where
1860 H: ExprHumanizer,
1861{
1862 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1863 self.err.humanize(self.humanizer, f)
1864 }
1865}
1866
1867impl<'a> std::fmt::Display for TypeError<'a> {
1868 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1869 TypeErrorHumanizer {
1870 err: self,
1871 humanizer: &DummyHumanizer,
1872 }
1873 .fmt(f)
1874 }
1875}
1876
1877impl<'a> TypeError<'a> {
1878 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1880 use TypeError::*;
1881 match self {
1882 Unbound { source, .. }
1883 | NoSuchColumn { source, .. }
1884 | MismatchColumn { source, .. }
1885 | MismatchColumns { source, .. }
1886 | BadConstantRowLen { source, .. }
1887 | BadConstantRow { source, .. }
1888 | BadProject { source, .. }
1889 | BadJoinEquivalence { source, .. }
1890 | BadTopKGroupKey { source, .. }
1891 | BadTopKOrdering { source, .. }
1892 | BadLetRecBindings { source }
1893 | Shadowing { source, .. }
1894 | DisallowedDummy { source, .. } => Some(source),
1895 Recursion { .. } => None,
1896 }
1897 }
1898
1899 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1900 where
1901 H: ExprHumanizer,
1902 {
1903 if let Some(source) = self.source() {
1904 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1905 }
1906
1907 use TypeError::*;
1908 match self {
1909 Unbound { source: _, id, typ } => {
1910 let typ = columns_pretty(&typ.column_types, humanizer);
1911 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1912 }
1913 NoSuchColumn {
1914 source: _,
1915 expr,
1916 col,
1917 } => writeln!(f, "{expr} references non-existent column {col}")?,
1918 MismatchColumn {
1919 source: _,
1920 got,
1921 expected,
1922 diffs,
1923 message,
1924 } => {
1925 let got = humanizer.humanize_column_type(got);
1926 let expected = humanizer.humanize_column_type(expected);
1927 writeln!(
1928 f,
1929 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1930 )?;
1931
1932 for diff in diffs {
1933 diff.humanize(2, humanizer, f)?;
1934 }
1935 }
1936 MismatchColumns {
1937 source: _,
1938 got,
1939 expected,
1940 diffs,
1941 message,
1942 } => {
1943 let got = columns_pretty(got, humanizer);
1944 let expected = columns_pretty(expected, humanizer);
1945
1946 writeln!(
1947 f,
1948 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1949 )?;
1950
1951 for diff in diffs {
1952 diff.humanize(humanizer, f)?;
1953 }
1954 }
1955 BadConstantRowLen {
1956 source: _,
1957 got,
1958 expected,
1959 } => {
1960 let expected = columns_pretty(expected, humanizer);
1961 writeln!(
1962 f,
1963 "bad constant row\n row has length {got}\nexpected row of type {expected}"
1964 )?
1965 }
1966 BadConstantRow {
1967 source: _,
1968 mismatches,
1969 expected,
1970 } => {
1971 let expected = columns_pretty(expected, humanizer);
1972
1973 let num_mismatches = mismatches.len();
1974 let plural = if num_mismatches == 1 { "" } else { "es" };
1975 writeln!(
1976 f,
1977 "bad constant row\n got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1978 )?;
1979
1980 if num_mismatches > 0 {
1981 writeln!(f, "")?;
1982 for (col, diff) in mismatches.iter() {
1983 writeln!(f, " column #{col}:")?;
1984 diff.humanize(8, humanizer, f)?;
1985 }
1986 }
1987 }
1988 BadProject {
1989 source: _,
1990 got,
1991 input_type,
1992 } => {
1993 let input_type = columns_pretty(input_type, humanizer);
1994
1995 writeln!(
1996 f,
1997 "projection of non-existant columns {got:?} from type {input_type}"
1998 )?
1999 }
2000 BadJoinEquivalence {
2001 source: _,
2002 got,
2003 message,
2004 } => {
2005 let got = columns_pretty(got, humanizer);
2006
2007 writeln!(f, "bad join equivalence {got}: {message}")?
2008 }
2009 BadTopKGroupKey {
2010 source: _,
2011 k,
2012 input_type,
2013 } => {
2014 let input_type = columns_pretty(input_type, humanizer);
2015
2016 writeln!(
2017 f,
2018 "TopK group key component references invalid column {k} in columns: {input_type}"
2019 )?
2020 }
2021 BadTopKOrdering {
2022 source: _,
2023 order,
2024 input_type,
2025 } => {
2026 let col = order.column;
2027 let num_cols = input_type.len();
2028 let are = if num_cols == 1 { "is" } else { "are" };
2029 let s = if num_cols == 1 { "" } else { "s" };
2030 let input_type = columns_pretty(input_type, humanizer);
2031
2032 let mode = HumanizedExplain::new(false);
2034 let order = mode.expr(order, None);
2035
2036 writeln!(
2037 f,
2038 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2039 )?
2040 }
2041 BadLetRecBindings { source: _ } => {
2042 writeln!(f, "LetRec ids and definitions don't line up")?
2043 }
2044 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2045 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2046 Recursion { error } => writeln!(f, "{error}")?,
2047 }
2048
2049 Ok(())
2050 }
2051}
2052
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{SqlColumnType, arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    /// An `Int16` datum matches a nullable `Int16` column type and does not
    /// match a non-nullable `Int32` column type.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        // Generates a column type together with a datum produced by
        // `arb_datum_for_column` for that type. Checks that
        // `datum_difference_with_column_type` succeeds exactly when
        // `Datum::is_instance_of` holds. Cases whose datum contains a dummy
        // are rejected rather than checked.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            // Convert the generated type and value into the `ReprColumnType`
            // and `Datum` that `datum_difference_with_column_type` takes.
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10000))]
        // Pairs an arbitrary column type with an independently generated
        // datum, so the datum need not fit the type. Checks that
        // `datum_difference_with_column_type` agrees with
        // `Datum::is_instance_of` either way.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            // Unlike the test above, a dummy here is treated as a generator bug
            // rather than a rejected case.
            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    /// Regression test (presumably for GitHub issue 10039, per its name).
    ///
    /// A list datum holding a single `Null` is not an instance of a
    /// non-nullable record type with one non-nullable `UInt32` field, and
    /// `datum_difference_with_column_type` must report an error for it.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}