1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27 ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28 SqlColumnType,
29};
30
/// A typechecking [`Context`] wrapped in `Arc<Mutex<..>>` so that a single
/// context value can be shared between several owners.
pub type SharedContext = Arc<Mutex<Context>>;
36
37pub fn empty_context() -> SharedContext {
39 Arc::new(Mutex::new(BTreeMap::new()))
40}
41
/// The ways in which typechecking a `MirRelationExpr` can fail.
///
/// Most variants carry `source`, a borrow of the relation expression that was
/// being checked when the problem was found.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// An identifier was looked up in the context but had no binding.
    Unbound {
        /// The expression containing the lookup.
        source: &'a MirRelationExpr,
        /// The identifier that was not bound.
        id: Id,
        /// The relation type annotated on the lookup.
        typ: ReprRelationType,
    },
    /// A scalar expression referred to a column index that the input does not have.
    NoSuchColumn {
        /// The relation expression containing the scalar expression.
        source: &'a MirRelationExpr,
        /// The offending scalar expression.
        expr: &'a MirScalarExpr,
        /// The column index that was out of range.
        col: usize,
    },
    /// A single column type did not match what was expected.
    MismatchColumn {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: ReprColumnType,
        /// The column type it was compared against.
        expected: ReprColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprColumnTypeDifference>,
        /// A description of the check that failed.
        message: String,
    },
    /// A whole list of column types did not match what was expected.
    MismatchColumns {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ReprColumnType>,
        /// The column types they were compared against.
        expected: Vec<ReprColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprRelationTypeDifference>,
        /// A description of the check that failed.
        message: String,
    },
    /// A constant row had a different number of datums than there are column types.
    BadConstantRowLen {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// The number of datums in the row.
        got: usize,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A constant row had the right length but some datums did not fit their column types.
    BadConstantRow {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// Each mismatching datum, as `(column index, difference)`.
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A projection listed a column index that the input does not have.
    BadProject {
        /// The projection expression.
        source: &'a MirRelationExpr,
        /// The projected column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class was rejected.
    BadJoinEquivalence {
        /// The join expression.
        source: &'a MirRelationExpr,
        /// The types of the expressions in the equivalence class.
        got: Vec<ReprColumnType>,
        /// A description of the check that failed.
        message: String,
    },
    /// A `TopK` group key referred to a column index that the input does not have.
    BadTopKGroupKey {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The out-of-range group key column.
        k: usize,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering referred to a column index that the input does not have.
    BadTopKOrdering {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The ordering whose column is out of range.
        order: ColumnOrder,
        /// The column types of the `TopK` input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` had a different number of ids than values.
    BadLetRecBindings {
        /// The `LetRec` expression.
        source: &'a MirRelationExpr,
    },
    /// A binding reused an identifier that was already bound.
    Shadowing {
        /// The binding expression.
        source: &'a MirRelationExpr,
        /// The identifier that was rebound.
        id: Id,
    },
    /// The recursion limit was exceeded while checking.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A dummy datum was found while dummies were disallowed.
    DisallowedDummy {
        /// The expression containing the dummy.
        source: &'a MirRelationExpr,
    },
}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173 fn from(error: RecursionLimitError) -> Self {
174 TypeError::Recursion { error }
175 }
176}
177
/// The typechecking environment: maps each bound [`Id`] to the column types
/// of the relation it refers to.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
/// One difference found when comparing two lists of column types
/// (see [`relation_subtype_difference`]).
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relations have different numbers of columns.
    Length {
        /// Number of columns on the `sub` side.
        len_sub: usize,
        /// Number of columns on the `sup` side.
        len_sup: usize,
    },
    /// The columns at one position differ.
    Column {
        /// The index of the differing column.
        col: usize,
        /// How the two column types differ.
        diff: ReprColumnTypeDifference,
    },
}
200
/// One difference found when comparing two column (or scalar) types
/// (see [`column_subtype_difference`] and [`scalar_subtype_difference`]).
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar types are not compatible.
    NotSubtype {
        /// The scalar type on the `sub` side.
        sub: ReprScalarType,
        /// The scalar type on the `sup` side.
        sup: ReprScalarType,
    },
    /// The columns differ in nullability.
    Nullability {
        /// The column type on the `sub` side.
        sub: ReprColumnType,
        /// The column type on the `sup` side.
        sup: ReprColumnType,
    },
    /// The difference lies in the element type of a container type.
    ElementType {
        /// Name of the container type constructor.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// A record is missing fields.
    // NOTE(review): not constructed anywhere in the visible part of this file —
    // confirm where it is produced.
    RecordMissingFields {
        /// The names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Differences among the fields of a record.
    RecordFields {
        /// The per-field differences.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
239
240impl ReprRelationTypeDifference {
241 pub fn ignore_nullability(self) -> Option<Self> {
245 use ReprRelationTypeDifference::*;
246
247 match self {
248 Length { .. } => Some(self),
249 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250 }
251 }
252}
253
254impl ReprColumnTypeDifference {
255 pub fn ignore_nullability(self) -> Option<Self> {
259 use ReprColumnTypeDifference::*;
260
261 match self {
262 Nullability { .. } => None,
263 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264 ElementType { ctor, element_type } => {
265 element_type
266 .ignore_nullability()
267 .map(|element_type| ElementType {
268 ctor,
269 element_type: Box::new(element_type),
270 })
271 }
272 RecordFields { fields } => {
273 let fields = fields
274 .into_iter()
275 .flat_map(|diff| diff.ignore_nullability())
276 .collect::<Vec<_>>();
277
278 if fields.is_empty() {
279 None
280 } else {
281 Some(RecordFields { fields })
282 }
283 }
284 }
285 }
286}
287
288pub fn relation_subtype_difference(
292 sub: &[ReprColumnType],
293 sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295 let mut diffs = Vec::new();
296
297 if sub.len() != sup.len() {
298 diffs.push(ReprRelationTypeDifference::Length {
299 len_sub: sub.len(),
300 len_sup: sup.len(),
301 });
302
303 return diffs;
305 }
306
307 diffs.extend(
308 sub.iter()
309 .zip_eq(sup.iter())
310 .enumerate()
311 .flat_map(|(col, (sub_ty, sup_ty))| {
312 column_subtype_difference(sub_ty, sup_ty)
313 .into_iter()
314 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315 }),
316 );
317
318 diffs
319}
320
321pub fn column_subtype_difference(
325 sub: &ReprColumnType,
326 sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330 if sub.nullable && !sup.nullable {
331 diffs.push(ReprColumnTypeDifference::Nullability {
332 sub: sub.clone(),
333 sup: sup.clone(),
334 });
335 }
336
337 diffs
338}
339
340pub fn scalar_subtype_difference(
344 sub: &ReprScalarType,
345 sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347 use ReprScalarType::*;
348
349 let mut diffs = Vec::new();
350
351 match (sub, sup) {
352 (
353 List {
354 element_type: sub_elt,
355 ..
356 },
357 List {
358 element_type: sup_elt,
359 ..
360 },
361 )
362 | (
363 Map {
364 value_type: sub_elt,
365 ..
366 },
367 Map {
368 value_type: sup_elt,
369 ..
370 },
371 )
372 | (
373 Range {
374 element_type: sub_elt,
375 ..
376 },
377 Range {
378 element_type: sup_elt,
379 ..
380 },
381 )
382 | (Array(sub_elt), Array(sup_elt)) => {
383 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384 diffs.extend(
385 scalar_subtype_difference(sub_elt, sup_elt)
386 .into_iter()
387 .map(|diff| ReprColumnTypeDifference::ElementType {
388 ctor: ctor.clone(),
389 element_type: Box::new(diff),
390 }),
391 );
392 }
393 (
394 Record {
395 fields: sub_fields, ..
396 },
397 Record {
398 fields: sup_fields, ..
399 },
400 ) => {
401 if sub_fields.len() != sup_fields.len() {
402 diffs.push(ReprColumnTypeDifference::NotSubtype {
403 sub: sub.clone(),
404 sup: sup.clone(),
405 });
406 return diffs;
407 }
408
409 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411 }
412 }
413 (_, _) => {
414 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415 diffs.push(ReprColumnTypeDifference::NotSubtype {
416 sub: sub.clone(),
417 sup: sup.clone(),
418 })
419 }
420 }
421 };
422
423 diffs
424}
425
426pub fn scalar_union(
430 typ: &mut ReprScalarType,
431 other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433 use ReprScalarType::*;
434
435 let mut diffs = Vec::new();
436
437 let ctor = ReprScalarBaseType::from(&*typ);
439 match (typ, other) {
440 (
441 List {
442 element_type: typ_elt,
443 },
444 List {
445 element_type: other_elt,
446 },
447 )
448 | (
449 Map {
450 value_type: typ_elt,
451 },
452 Map {
453 value_type: other_elt,
454 },
455 )
456 | (
457 Range {
458 element_type: typ_elt,
459 },
460 Range {
461 element_type: other_elt,
462 },
463 )
464 | (Array(typ_elt), Array(other_elt)) => {
465 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466 diffs.extend(
467 res.into_iter()
468 .map(|diff| ReprColumnTypeDifference::ElementType {
469 ctor: format!("{ctor:?}"),
470 element_type: Box::new(diff),
471 }),
472 );
473 }
474 (
475 Record { fields: typ_fields },
476 Record {
477 fields: other_fields,
478 },
479 ) => {
480 if typ_fields.len() != other_fields.len() {
481 diffs.push(ReprColumnTypeDifference::NotSubtype {
482 sub: ReprScalarType::Record {
483 fields: typ_fields.clone(),
484 },
485 sup: other.clone(),
486 });
487 return diffs;
488 }
489
490 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491 diffs.extend(column_union(typ_ty, other_ty));
492 }
493 }
494 (typ, _) => {
495 if ctor != ReprScalarBaseType::from(other) {
496 diffs.push(ReprColumnTypeDifference::NotSubtype {
497 sub: typ.clone(),
498 sup: other.clone(),
499 })
500 }
501 }
502 };
503
504 diffs
505}
506
507pub fn column_union(
511 typ: &mut ReprColumnType,
512 other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516 if diffs.is_empty() {
517 typ.nullable |= other.nullable;
518 }
519
520 diffs
521}
522
523pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528 if sub.len() != sup.len() {
529 return false;
530 }
531
532 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534 })
535}
536
/// The way in which a datum fails to fit a type.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// The datum was null where a null is not accepted.
    Null {
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// The datum is of a different kind than the expected type.
    Mismatch {
        /// The `Debug` rendering of the datum.
        got_debug: String,
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// A container datum had an unexpected size or number of dimensions.
    MismatchDimensions {
        /// Name of the container type being checked.
        ctor: String,
        /// The size or dimension count that was found.
        got: usize,
        /// The size or dimension count that was expected.
        expected: usize,
    },
    /// An element nested inside a container datum did not fit.
    ElementType {
        /// Name of the container type being checked.
        ctor: String,
        /// The difference found for the nested element.
        element_type: Box<DatumTypeDifference>,
    },
}
570
571fn datum_difference_with_column_type(
577 datum: &Datum<'_>,
578 column_type: &ReprColumnType,
579) -> Result<(), DatumTypeDifference> {
580 fn difference_with_scalar_type(
581 datum: &Datum<'_>,
582 scalar_type: &ReprScalarType,
583 ) -> Result<(), DatumTypeDifference> {
584 fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
585 Err(DatumTypeDifference::Mismatch {
586 got_debug: format!("{got:?}"),
588 expected: expected.clone(),
589 })
590 }
591
592 if let ReprScalarType::Jsonb = scalar_type {
593 match datum {
595 Datum::Dummy => Ok(()), Datum::Null => Err(DatumTypeDifference::Null {
597 expected: ReprScalarType::Jsonb,
598 }),
599 Datum::JsonNull
600 | Datum::False
601 | Datum::True
602 | Datum::Numeric(_)
603 | Datum::String(_) => Ok(()),
604 Datum::List(list) => {
605 for elem in list.iter() {
606 difference_with_scalar_type(&elem, scalar_type)?;
607 }
608 Ok(())
609 }
610 Datum::Map(dict) => {
611 for (_, val) in dict.iter() {
612 difference_with_scalar_type(&val, scalar_type)?;
613 }
614 Ok(())
615 }
616 _ => mismatch(datum, scalar_type),
617 }
618 } else {
619 fn element_type_difference(
620 ctor: &str,
621 element_type: DatumTypeDifference,
622 ) -> DatumTypeDifference {
623 DatumTypeDifference::ElementType {
624 ctor: ctor.to_string(),
625 element_type: Box::new(element_type),
626 }
627 }
628 match (datum, scalar_type) {
629 (Datum::Dummy, _) => Ok(()), (Datum::Null, _) => Err(DatumTypeDifference::Null {
631 expected: scalar_type.clone(),
632 }),
633 (Datum::False, ReprScalarType::Bool) => Ok(()),
634 (Datum::False, _) => mismatch(datum, scalar_type),
635 (Datum::True, ReprScalarType::Bool) => Ok(()),
636 (Datum::True, _) => mismatch(datum, scalar_type),
637 (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
638 (Datum::Int16(_), _) => mismatch(datum, scalar_type),
639 (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
640 (Datum::Int32(_), _) => mismatch(datum, scalar_type),
641 (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
642 (Datum::Int64(_), _) => mismatch(datum, scalar_type),
643 (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
644 (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
645 (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
646 (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
647 (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
648 (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
649 (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
650 (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
651 (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
652 (Datum::Float32(_), _) => mismatch(datum, scalar_type),
653 (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
654 (Datum::Float64(_), _) => mismatch(datum, scalar_type),
655 (Datum::Date(_), ReprScalarType::Date) => Ok(()),
656 (Datum::Date(_), _) => mismatch(datum, scalar_type),
657 (Datum::Time(_), ReprScalarType::Time) => Ok(()),
658 (Datum::Time(_), _) => mismatch(datum, scalar_type),
659 (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
660 (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
661 (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
662 (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
663 (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
664 (Datum::Interval(_), _) => mismatch(datum, scalar_type),
665 (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
666 (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
667 (Datum::String(_), ReprScalarType::String) => Ok(()),
668 (Datum::String(_), _) => mismatch(datum, scalar_type),
669 (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
670 (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
671 (Datum::Array(array), ReprScalarType::Array(t)) => {
672 for e in array.elements().iter() {
673 if let Datum::Null = e {
674 continue;
675 }
676
677 difference_with_scalar_type(&e, t)
678 .map_err(|e| element_type_difference("array", e))?;
679 }
680 Ok(())
681 }
682 (Datum::Array(array), ReprScalarType::Int2Vector) => {
683 if array.dims().len() != 1 {
684 return Err(DatumTypeDifference::MismatchDimensions {
685 ctor: "int2vector".to_string(),
686 got: array.dims().len(),
687 expected: 1,
688 });
689 }
690
691 for e in array.elements().iter() {
692 difference_with_scalar_type(&e, &ReprScalarType::Int16)
693 .map_err(|e| element_type_difference("int2vector", e))?;
694 }
695
696 Ok(())
697 }
698 (Datum::Array(_), _) => mismatch(datum, scalar_type),
699 (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
700 for e in list.iter() {
701 if let Datum::Null = e {
702 continue;
703 }
704
705 difference_with_scalar_type(&e, element_type)
706 .map_err(|e| element_type_difference("list", e))?;
707 }
708 Ok(())
709 }
710 (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
711 let len = list.iter().count();
712 if len != fields.len() {
713 return Err(DatumTypeDifference::MismatchDimensions {
714 ctor: "record".to_string(),
715 got: len,
716 expected: fields.len(),
717 });
718 }
719
720 for (e, t) in list.iter().zip_eq(fields) {
721 if let Datum::Null = e {
722 if t.nullable {
723 continue;
724 } else {
725 return Err(DatumTypeDifference::Null {
726 expected: t.scalar_type.clone(),
727 });
728 }
729 }
730
731 difference_with_scalar_type(&e, &t.scalar_type)
732 .map_err(|e| element_type_difference("record", e))?;
733 }
734 Ok(())
735 }
736 (Datum::List(_), _) => mismatch(datum, scalar_type),
737 (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
738 for (_, v) in map.iter() {
739 if let Datum::Null = v {
740 continue;
741 }
742
743 difference_with_scalar_type(&v, value_type)
744 .map_err(|e| element_type_difference("map", e))?;
745 }
746 Ok(())
747 }
748 (Datum::Map(_), _) => mismatch(datum, scalar_type),
749 (Datum::JsonNull, _) => mismatch(datum, scalar_type),
750 (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
751 (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
752 (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
753 (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
754 (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
755 match inner {
756 None => Ok(()),
757 Some(inner) => {
758 if let Some(b) = inner.lower.bound {
759 difference_with_scalar_type(&b.datum(), element_type)
760 .map_err(|e| element_type_difference("range", e))?;
761 }
762 if let Some(b) = inner.upper.bound {
763 difference_with_scalar_type(&b.datum(), element_type)
764 .map_err(|e| element_type_difference("range", e))?;
765 }
766 Ok(())
767 }
768 }
769 }
770 (Datum::Range(_), _) => mismatch(datum, scalar_type),
771 (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
772 (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
773 (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
774 (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
775 }
776 }
777 }
778 if column_type.nullable {
779 if let Datum::Null = datum {
780 return Ok(());
781 }
782 }
783 difference_with_scalar_type(datum, &column_type.scalar_type)
784}
785
786fn row_difference_with_column_types<'a>(
787 source: &'a MirRelationExpr,
788 datums: &Vec<Datum<'_>>,
789 column_types: &[ReprColumnType],
790) -> Result<(), TypeError<'a>> {
791 if datums.len() != column_types.len() {
793 return Err(TypeError::BadConstantRowLen {
794 source,
795 got: datums.len(),
796 expected: column_types.to_vec(),
797 });
798 }
799
800 let mut mismatches = Vec::new();
802 for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
803 if let Err(e) = datum_difference_with_column_type(d, ty) {
804 mismatches.push((i, e));
805 }
806 }
807 if !mismatches.is_empty() {
808 return Err(TypeError::BadConstantRow {
809 source,
810 mismatches,
811 expected: column_types.to_vec(),
812 });
813 }
814
815 Ok(())
816}
/// A typechecker for `MirRelationExpr`s, configured through the builder-style
/// methods on this type.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared typechecking context.
    // NOTE(review): not read in the visible part of this file — confirm how it
    // is used.
    ctx: SharedContext,
    /// Set by `disallow_new_globals`.
    // NOTE(review): not read in the visible part of this file — confirm its
    // effect.
    disallow_new_globals: bool,
    /// When set, join equivalence classes are additionally checked for
    /// nullability.
    strict_join_equivalences: bool,
    /// When set, expressions containing dummy datums are rejected.
    disallow_dummy: bool,
    /// Guards the recursive checks against exceeding the recursion limit.
    recursion_guard: RecursionGuard,
}
831
832impl CheckedRecursion for Typecheck {
833 fn recursion_guard(&self) -> &RecursionGuard {
834 &self.recursion_guard
835 }
836}
837
838impl Typecheck {
839 pub fn new(ctx: SharedContext) -> Self {
841 Self {
842 ctx,
843 disallow_new_globals: false,
844 strict_join_equivalences: false,
845 disallow_dummy: false,
846 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
847 }
848 }
849
850 pub fn disallow_new_globals(mut self) -> Self {
854 self.disallow_new_globals = true;
855 self
856 }
857
858 pub fn strict_join_equivalences(mut self) -> Self {
862 self.strict_join_equivalences = true;
863
864 self
865 }
866
867 pub fn disallow_dummy(mut self) -> Self {
869 self.disallow_dummy = true;
870 self
871 }
872
873 pub fn typecheck<'a>(
884 &self,
885 expr: &'a MirRelationExpr,
886 ctx: &Context,
887 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
888 use MirRelationExpr::*;
889
890 self.checked_recur(|tc| match expr {
891 Constant { typ, rows } => {
892 if let Ok(rows) = rows {
893 for (row, _id) in rows {
894 let datums = row.unpack();
895
896 row_difference_with_column_types(expr, &datums, &typ.column_types.iter().map(ReprColumnType::from).collect_vec())?;
897
898 if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
899 return Err(TypeError::DisallowedDummy {
900 source: expr,
901 });
902 }
903 }
904 }
905
906 Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
907 }
908 Get { typ, id, .. } => {
909 if let Id::Global(_global_id) = id {
910 if !ctx.contains_key(id) {
911 return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
913 }
914 }
915
916 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
917 source: expr,
918 id: id.clone(),
919 typ: ReprRelationType::from(typ),
920 })?;
921
922 let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
923
924 let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
926
927 if !diffs.is_empty() {
928 return Err(TypeError::MismatchColumns {
929 source: expr,
930 got: column_types,
931 expected: ctx_typ.clone(),
932 diffs,
933 message: "annotation did not match context type".to_string(),
934 });
935 }
936
937 Ok(column_types)
938 }
939 Project { input, outputs } => {
940 let t_in = tc.typecheck(input, ctx)?;
941
942 for x in outputs {
943 if *x >= t_in.len() {
944 return Err(TypeError::BadProject {
945 source: expr,
946 got: outputs.clone(),
947 input_type: t_in,
948 });
949 }
950 }
951
952 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
953 }
954 Map { input, scalars } => {
955 let mut t_in = tc.typecheck(input, ctx)?;
956
957 for scalar_expr in scalars.iter() {
958 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
959
960 if self.disallow_dummy && scalar_expr.contains_dummy() {
961 return Err(TypeError::DisallowedDummy {
962 source: expr,
963 });
964 }
965 }
966
967 Ok(t_in)
968 }
969 FlatMap { input, func, exprs } => {
970 let mut t_in = tc.typecheck(input, ctx)?;
971
972 let mut t_exprs = Vec::with_capacity(exprs.len());
973 for scalar_expr in exprs {
974 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
975
976 if self.disallow_dummy && scalar_expr.contains_dummy() {
977 return Err(TypeError::DisallowedDummy {
978 source: expr,
979 });
980 }
981 }
982 let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
985
986 t_in.extend(t_out);
988 Ok(t_in)
989 }
990 Filter { input, predicates } => {
991 let mut t_in = tc.typecheck(input, ctx)?;
992
993 for column in non_nullable_columns(predicates) {
996 t_in[column].nullable = false;
997 }
998
999 for scalar_expr in predicates {
1000 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1001
1002 if t.scalar_type != ReprScalarType::Bool {
1006 let sub = t.scalar_type.clone();
1007
1008 return Err(TypeError::MismatchColumn {
1009 source: expr,
1010 got: t,
1011 expected: ReprColumnType {
1012 scalar_type: ReprScalarType::Bool,
1013 nullable: true,
1014 },
1015 diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
1016 message: "expected boolean condition".to_string(),
1017 });
1018 }
1019
1020 if self.disallow_dummy && scalar_expr.contains_dummy() {
1021 return Err(TypeError::DisallowedDummy {
1022 source: expr,
1023 });
1024 }
1025 }
1026
1027 Ok(t_in)
1028 }
1029 Join {
1030 inputs,
1031 equivalences,
1032 implementation,
1033 } => {
1034 let mut t_in_global = Vec::new();
1035 let mut t_in_local = vec![Vec::new(); inputs.len()];
1036
1037 for (i, input) in inputs.iter().enumerate() {
1038 let input_t = tc.typecheck(input, ctx)?;
1039 t_in_global.extend(input_t.clone());
1040 t_in_local[i] = input_t;
1041 }
1042
1043 for eq_class in equivalences {
1044 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1045
1046 let mut all_nullable = true;
1047
1048 for scalar_expr in eq_class {
1049 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1051
1052 if !t_expr.nullable {
1053 all_nullable = false;
1054 }
1055
1056 if let Some(t_first) = t_exprs.get(0) {
1057 let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
1058 if !diffs.is_empty() {
1059 return Err(TypeError::MismatchColumn {
1060 source: expr,
1061 got: t_expr,
1062 expected: t_first.clone(),
1063 diffs,
1064 message: "equivalence class members have different scalar types".to_string(),
1065 });
1066 }
1067
1068 if self.strict_join_equivalences {
1072 if t_expr.nullable != t_first.nullable {
1073 let sub = t_expr.clone();
1074 let sup = t_first.clone();
1075
1076 let err = TypeError::MismatchColumn {
1077 source: expr,
1078 got: t_expr.clone(),
1079 expected: t_first.clone(),
1080 diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
1081 message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
1082 };
1083
1084 ::tracing::debug!("{err}");
1086 }
1087 }
1088 }
1089
1090 if self.disallow_dummy && scalar_expr.contains_dummy() {
1091 return Err(TypeError::DisallowedDummy {
1092 source: expr,
1093 });
1094 }
1095
1096 t_exprs.push(t_expr);
1097 }
1098
1099 if self.strict_join_equivalences && all_nullable {
1100 let err = TypeError::BadJoinEquivalence {
1101 source: expr,
1102 got: t_exprs,
1103 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1104 };
1105
1106 ::tracing::debug!("{err}");
1108 }
1109 }
1110
1111 match implementation {
1113 JoinImplementation::Differential((start_idx, first_key, _), others) => {
1114 if let Some(key) = first_key {
1115 for k in key {
1116 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1117 }
1118 }
1119
1120 for (idx, key, _) in others {
1121 for k in key {
1122 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1123 }
1124 }
1125 }
1126 JoinImplementation::DeltaQuery(plans) => {
1127 for plan in plans {
1128 for (idx, key, _) in plan {
1129 for k in key {
1130 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1131 }
1132 }
1133 }
1134 }
1135 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1136 let typ: Vec<ReprColumnType> = key
1137 .iter()
1138 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1139 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1140
1141 for row in consts {
1142 let datums = row.unpack();
1143
1144 row_difference_with_column_types(expr, &datums, &typ)?;
1145 }
1146 }
1147 JoinImplementation::Unimplemented => (),
1148 }
1149
1150 Ok(t_in_global)
1151 }
1152 Reduce {
1153 input,
1154 group_key,
1155 aggregates,
1156 monotonic: _,
1157 expected_group_size: _,
1158 } => {
1159 let t_in = tc.typecheck(input, ctx)?;
1160
1161 let mut t_out = group_key
1162 .iter()
1163 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1164 .collect::<Result<Vec<_>, _>>()?;
1165
1166 if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
1167 return Err(TypeError::DisallowedDummy {
1168 source: expr,
1169 });
1170 }
1171
1172 for agg in aggregates {
1173 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1174 }
1175
1176 Ok(t_out)
1177 }
1178 TopK {
1179 input,
1180 group_key,
1181 order_key,
1182 limit: _,
1183 offset: _,
1184 monotonic: _,
1185 expected_group_size: _,
1186 } => {
1187 let t_in = tc.typecheck(input, ctx)?;
1188
1189 for &k in group_key {
1190 if k >= t_in.len() {
1191 return Err(TypeError::BadTopKGroupKey {
1192 source: expr,
1193 k,
1194 input_type: t_in,
1195 });
1196 }
1197 }
1198
1199 for order in order_key {
1200 if order.column >= t_in.len() {
1201 return Err(TypeError::BadTopKOrdering {
1202 source: expr,
1203 order: order.clone(),
1204 input_type: t_in,
1205 });
1206 }
1207 }
1208
1209 Ok(t_in)
1210 }
1211 Negate { input } => tc.typecheck(input, ctx),
1212 Threshold { input } => tc.typecheck(input, ctx),
1213 Union { base, inputs } => {
1214 let mut t_base = tc.typecheck(base, ctx)?;
1215
1216 for input in inputs {
1217 let t_input = tc.typecheck(input, ctx)?;
1218
1219 let len_sub = t_base.len();
1220 let len_sup = t_input.len();
1221 if len_sub != len_sup {
1222 return Err(TypeError::MismatchColumns {
1223 source: expr,
1224 got: t_base.clone(),
1225 expected: t_input,
1226 diffs: vec![ReprRelationTypeDifference::Length {
1227 len_sub,
1228 len_sup,
1229 }],
1230 message: "Union branches have different numbers of columns".to_string(),
1231 });
1232 }
1233
1234 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1235 let diffs = column_union(base_col, &input_col);
1236 if !diffs.is_empty() {
1237 return Err(TypeError::MismatchColumn {
1238 source: expr,
1239 got: input_col,
1240 expected: base_col.clone(),
1241 diffs,
1242 message:
1243 "couldn't compute union of column types in Union"
1244 .to_string(),
1245 });
1246 }
1247
1248 }
1249 }
1250
1251 Ok(t_base)
1252 }
1253 Let { id, value, body } => {
1254 let t_value = tc.typecheck(value, ctx)?;
1255
1256 let binding = Id::Local(*id);
1257 if ctx.contains_key(&binding) {
1258 return Err(TypeError::Shadowing {
1259 source: expr,
1260 id: binding,
1261 });
1262 }
1263
1264 let mut body_ctx = ctx.clone();
1265 body_ctx.insert(Id::Local(*id), t_value);
1266
1267 tc.typecheck(body, &body_ctx)
1268 }
1269 LetRec { ids, values, body, limits: _ } => {
1270 if ids.len() != values.len() {
1271 return Err(TypeError::BadLetRecBindings { source: expr });
1272 }
1273
1274 let mut ctx = ctx.clone();
1277 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1279 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1280 }
1281
1282 for (id, value) in ids.iter().zip_eq(values.iter()) {
1283 let typ = tc.typecheck(value, &ctx)?;
1284
1285 let id = Id::Local(id.clone());
1286 if let Some(ctx_typ) = ctx.get_mut(&id) {
1287 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1288 let diffs = column_union(base_col, &input_col);
1290 if !diffs.is_empty() {
1291 return Err(TypeError::MismatchColumn {
1292 source: expr,
1293 got: input_col,
1294 expected: base_col.clone(),
1295 diffs,
1296 message:
1297 "couldn't compute union of column types in LetRec"
1298 .to_string(),
1299 })
1300 }
1301 }
1302 } else {
1303 ctx.insert(id, typ);
1305 }
1306 }
1307
1308 tc.typecheck(body, &ctx)
1309 }
1310 ArrangeBy { input, keys } => {
1311 let t_in = tc.typecheck(input, ctx)?;
1312
1313 for key in keys {
1314 for k in key {
1315 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1316 }
1317 }
1318
1319 Ok(t_in)
1320 }
1321 })
1322 }
1323
1324 fn collect_recursive_variable_types<'a>(
1328 &self,
1329 expr: &'a MirRelationExpr,
1330 ids: &[LocalId],
1331 ctx: &mut Context,
1332 ) -> Result<(), TypeError<'a>> {
1333 use MirRelationExpr::*;
1334
1335 self.checked_recur(|tc| {
1336 match expr {
1337 Get {
1338 id: Id::Local(id),
1339 typ,
1340 ..
1341 } => {
1342 if !ids.contains(id) {
1343 return Ok(());
1344 }
1345
1346 let id = Id::Local(id.clone());
1347 if let Some(ctx_typ) = ctx.get_mut(&id) {
1348 let typ = typ
1349 .column_types
1350 .iter()
1351 .map(ReprColumnType::from)
1352 .collect_vec();
1353
1354 if ctx_typ.len() != typ.len() {
1355 let diffs = relation_subtype_difference(&typ, ctx_typ);
1356
1357 return Err(TypeError::MismatchColumns {
1358 source: expr,
1359 got: typ,
1360 expected: ctx_typ.clone(),
1361 diffs,
1362 message: "environment and type annotation did not match"
1363 .to_string(),
1364 });
1365 }
1366
1367 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1368 let diffs = column_union(base_col, &input_col);
1369 if !diffs.is_empty() {
1370 return Err(TypeError::MismatchColumn {
1371 source: expr,
1372 got: input_col,
1373 expected: base_col.clone(),
1374 diffs,
1375 message:
1376 "couldn't compute union of column types in Get and context"
1377 .to_string(),
1378 });
1379 }
1380 }
1381 } else {
1382 ctx.insert(
1383 id,
1384 typ.column_types
1385 .iter()
1386 .map(ReprColumnType::from)
1387 .collect_vec(),
1388 );
1389 }
1390 }
1391 Get {
1392 id: Id::Global(..), ..
1393 }
1394 | Constant { .. } => (),
1395 Let { id, value, body } => {
1396 tc.collect_recursive_variable_types(value, ids, ctx)?;
1397
1398 if ids.contains(id) {
1400 return Err(TypeError::Shadowing {
1401 source: expr,
1402 id: Id::Local(*id),
1403 });
1404 }
1405
1406 tc.collect_recursive_variable_types(body, ids, ctx)?;
1407 }
1408 LetRec {
1409 ids: inner_ids,
1410 values,
1411 body,
1412 limits: _,
1413 } => {
1414 for inner_id in inner_ids {
1415 if ids.contains(inner_id) {
1416 return Err(TypeError::Shadowing {
1417 source: expr,
1418 id: Id::Local(*inner_id),
1419 });
1420 }
1421 }
1422
1423 for value in values {
1424 tc.collect_recursive_variable_types(value, ids, ctx)?;
1425 }
1426
1427 tc.collect_recursive_variable_types(body, ids, ctx)?;
1428 }
1429 Project { input, .. }
1430 | Map { input, .. }
1431 | FlatMap { input, .. }
1432 | Filter { input, .. }
1433 | Reduce { input, .. }
1434 | TopK { input, .. }
1435 | Negate { input }
1436 | Threshold { input }
1437 | ArrangeBy { input, .. } => {
1438 tc.collect_recursive_variable_types(input, ids, ctx)?;
1439 }
1440 Join { inputs, .. } => {
1441 for input in inputs {
1442 tc.collect_recursive_variable_types(input, ids, ctx)?;
1443 }
1444 }
1445 Union { base, inputs } => {
1446 tc.collect_recursive_variable_types(base, ids, ctx)?;
1447
1448 for input in inputs {
1449 tc.collect_recursive_variable_types(input, ids, ctx)?;
1450 }
1451 }
1452 }
1453
1454 Ok(())
1455 })
1456 }
1457
1458 fn typecheck_scalar<'a>(
1459 &self,
1460 expr: &'a MirScalarExpr,
1461 source: &'a MirRelationExpr,
1462 column_types: &[ReprColumnType],
1463 ) -> Result<ReprColumnType, TypeError<'a>> {
1464 use MirScalarExpr::*;
1465
1466 self.checked_recur(|tc| match expr {
1467 Column(i, _) => match column_types.get(*i) {
1468 Some(ty) => Ok(ty.clone()),
1469 None => Err(TypeError::NoSuchColumn {
1470 source,
1471 expr,
1472 col: *i,
1473 }),
1474 },
1475 Literal(row, typ) => {
1476 let typ = ReprColumnType::from(typ);
1477 if let Ok(row) = row {
1478 let datums = row.unpack();
1479
1480 row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1481 }
1482
1483 Ok(typ)
1484 }
1485 CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1486 CallUnary { expr, func } => {
1487 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1488 let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1489 Ok(ReprColumnType::from(&typ_out))
1490 }
1491 CallBinary { expr1, expr2, func } => {
1492 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1493 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1494 let typ_out = func.output_type(
1495 SqlColumnType::from_repr(&typ_in1),
1496 SqlColumnType::from_repr(&typ_in2),
1497 );
1498 Ok(ReprColumnType::from(&typ_out))
1499 }
1500 CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1501 &func.output_type(
1502 exprs
1503 .iter()
1504 .map(|e| {
1505 tc.typecheck_scalar(e, source, column_types)
1506 .map(|typ| SqlColumnType::from_repr(&typ))
1507 })
1508 .collect::<Result<Vec<_>, TypeError>>()?,
1509 ),
1510 )),
1511 If { cond, then, els } => {
1512 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1513
1514 if cond_type.scalar_type != ReprScalarType::Bool {
1518 let sub = cond_type.scalar_type.clone();
1519
1520 return Err(TypeError::MismatchColumn {
1521 source,
1522 got: cond_type,
1523 expected: ReprColumnType {
1524 scalar_type: ReprScalarType::Bool,
1525 nullable: true,
1526 },
1527 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1528 sub,
1529 sup: ReprScalarType::Bool,
1530 }],
1531 message: "expected boolean condition".to_string(),
1532 });
1533 }
1534
1535 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1536 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1537
1538 let diffs = column_union(&mut then_type, &else_type);
1539 if !diffs.is_empty() {
1540 return Err(TypeError::MismatchColumn {
1541 source,
1542 got: then_type,
1543 expected: else_type,
1544 diffs,
1545 message: "couldn't compute union of column types for If".to_string(),
1546 });
1547 }
1548
1549 Ok(then_type)
1550 }
1551 })
1552 }
1553
1554 pub fn typecheck_aggregate<'a>(
1556 &self,
1557 expr: &'a AggregateExpr,
1558 source: &'a MirRelationExpr,
1559 column_types: &[ReprColumnType],
1560 ) -> Result<ReprColumnType, TypeError<'a>> {
1561 self.checked_recur(|tc| {
1562 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1563
1564 Ok(ReprColumnType::from(
1567 &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1568 ))
1569 })
1570 }
1571}
1572
/// Reports a type-checking diagnostic at a level selected by `$severity`.
///
/// `type_error!(severity, fmt, args...)`: when `severity` evaluates to `true`
/// the message is forwarded to `soft_panic_or_log!`; otherwise it is emitted
/// through `::tracing::debug!`. The remaining arguments are passed through
/// unchanged as format arguments.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        // Evaluate the severity expression exactly once.
        let is_severe: bool = $severity;
        if !is_severe {
            ::tracing::debug!($($arg)+);
        } else {
            soft_panic_or_log!($($arg)+);
        }
    }}
}
1585
1586impl crate::Transform for Typecheck {
1587 fn name(&self) -> &'static str {
1588 "Typecheck"
1589 }
1590
1591 fn actually_perform_transform(
1592 &self,
1593 relation: &mut MirRelationExpr,
1594 transform_ctx: &mut crate::TransformCtx,
1595 ) -> Result<(), crate::TransformError> {
1596 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1597
1598 let expected = transform_ctx
1599 .global_id
1600 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1601
1602 if let Some(id) = transform_ctx.global_id {
1603 if self.disallow_new_globals
1604 && expected.is_none()
1605 && transform_ctx.global_id.is_some()
1606 && !id.is_transient()
1607 {
1608 type_error!(
1609 false, "type warning: new non-transient global id {id}\n{}",
1611 relation.pretty()
1612 );
1613 }
1614 }
1615
1616 let got = self.typecheck(relation, &typecheck_ctx);
1617
1618 let humanizer = mz_repr::explain::DummyHumanizer;
1619
1620 match (got, expected) {
1621 (Ok(got), Some(expected)) => {
1622 let id = transform_ctx.global_id.unwrap();
1623
1624 let diffs = relation_subtype_difference(expected, &got);
1626 if !diffs.is_empty() {
1627 let severity = diffs
1629 .iter()
1630 .any(|diff| diff.clone().ignore_nullability().is_some());
1631
1632 let err = TypeError::MismatchColumns {
1633 source: relation,
1634 got,
1635 expected: expected.clone(),
1636 diffs,
1637 message: format!(
1638 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1639 ),
1640 };
1641
1642 type_error!(severity, "type error in known global id {id}:\n{err}");
1643 }
1644 }
1645 (Ok(got), None) => {
1646 if let Some(id) = transform_ctx.global_id {
1647 typecheck_ctx.insert(Id::Global(id), got);
1648 }
1649 }
1650 (Err(err), _) => {
1651 let (expected, binding) = match expected {
1652 Some(expected) => {
1653 let id = transform_ctx.global_id.unwrap();
1654 (
1655 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1656 format!("known global id {id}"),
1657 )
1658 }
1659 None => ("".to_string(), "transient query".to_string()),
1660 };
1661
1662 type_error!(
1663 true, "type error in {binding}:\n{err}\n{expected}{}",
1665 relation.pretty()
1666 );
1667 }
1668 }
1669
1670 Ok(())
1671 }
1672}
1673
1674pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1676where
1677 H: ExprHumanizer,
1678{
1679 let mut s = String::with_capacity(2 + 3 * cols.len());
1680
1681 s.push('(');
1682
1683 let mut it = cols.iter().peekable();
1684 while let Some(col) = it.next() {
1685 s.push_str(&humanizer.humanize_column_type_repr(col, false));
1686
1687 if it.peek().is_some() {
1688 s.push_str(", ");
1689 }
1690 }
1691
1692 s.push(')');
1693
1694 s
1695}
1696
1697impl ReprRelationTypeDifference {
1698 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1702 where
1703 H: ExprHumanizer,
1704 {
1705 use ReprRelationTypeDifference::*;
1706 match self {
1707 Length { len_sub, len_sup } => {
1708 writeln!(
1709 f,
1710 " number of columns do not match ({len_sub} != {len_sup})"
1711 )
1712 }
1713 Column { col, diff } => {
1714 writeln!(f, " column {col} differs:")?;
1715 diff.humanize(4, h, f)
1716 }
1717 }
1718 }
1719}
1720
1721impl ReprColumnTypeDifference {
1722 pub fn humanize<H>(
1724 &self,
1725 indent: usize,
1726 h: &H,
1727 f: &mut std::fmt::Formatter<'_>,
1728 ) -> std::fmt::Result
1729 where
1730 H: ExprHumanizer,
1731 {
1732 use ReprColumnTypeDifference::*;
1733
1734 write!(f, "{:indent$}", "")?;
1736
1737 match self {
1738 NotSubtype { sub, sup } => {
1739 let sub = h.humanize_scalar_type_repr(sub, false);
1740 let sup = h.humanize_scalar_type_repr(sup, false);
1741
1742 writeln!(f, "{sub} is a not a subtype of {sup}")
1743 }
1744 Nullability { sub, sup } => {
1745 let sub = h.humanize_column_type_repr(sub, false);
1746 let sup = h.humanize_column_type_repr(sup, false);
1747
1748 writeln!(f, "{sub} is nullable but {sup} is not")
1749 }
1750 ElementType { ctor, element_type } => {
1751 writeln!(f, "{ctor} element types differ:")?;
1752
1753 element_type.humanize(indent + 2, h, f)
1754 }
1755 RecordMissingFields { missing } => {
1756 write!(f, "missing column fields:")?;
1757 for col in missing {
1758 write!(f, " {col}")?;
1759 }
1760 f.write_char('\n')
1761 }
1762 RecordFields { fields } => {
1763 writeln!(f, "{} record fields differ:", fields.len())?;
1764
1765 for (i, diff) in fields.iter().enumerate() {
1766 writeln!(f, "{:indent$} field {i}:", "")?;
1767 diff.humanize(indent + 4, h, f)?;
1768 }
1769 Ok(())
1770 }
1771 }
1772 }
1773}
1774
1775impl DatumTypeDifference {
1776 pub fn humanize<H>(
1778 &self,
1779 indent: usize,
1780 h: &H,
1781 f: &mut std::fmt::Formatter<'_>,
1782 ) -> std::fmt::Result
1783 where
1784 H: ExprHumanizer,
1785 {
1786 write!(f, "{:indent$}", "")?;
1788
1789 match self {
1790 DatumTypeDifference::Null { expected } => {
1791 let expected = h.humanize_scalar_type_repr(expected, false);
1792 writeln!(
1793 f,
1794 "unexpected null, expected representation type {expected}"
1795 )?
1796 }
1797 DatumTypeDifference::Mismatch {
1798 got_debug,
1799 expected,
1800 } => {
1801 let expected = h.humanize_scalar_type_repr(expected, false);
1802 writeln!(
1804 f,
1805 "got datum {got_debug}, expected representation type {expected}"
1806 )?;
1807 }
1808 DatumTypeDifference::MismatchDimensions {
1809 ctor,
1810 got,
1811 expected,
1812 } => {
1813 writeln!(
1814 f,
1815 "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1816 )?;
1817 }
1818 DatumTypeDifference::ElementType { ctor, element_type } => {
1819 writeln!(f, "{ctor} element types differ:")?;
1820 element_type.humanize(indent + 4, h, f)?;
1821 }
1822 }
1823
1824 Ok(())
1825 }
1826}
1827
#[allow(missing_debug_implementations)]
/// Pairs a [`TypeError`] with an [`ExprHumanizer`] so that the error can be
/// formatted via `Display` using that humanizer.
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to be rendered.
    err: &'a TypeError<'a>,
    /// The humanizer used while rendering `err`.
    humanizer: &'b H,
}
1837
1838impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1839where
1840 H: ExprHumanizer,
1841{
1842 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1844 Self { err, humanizer }
1845 }
1846}
1847
1848impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1849where
1850 H: ExprHumanizer,
1851{
1852 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1853 self.err.humanize(self.humanizer, f)
1854 }
1855}
1856
1857impl<'a> std::fmt::Display for TypeError<'a> {
1858 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1859 TypeErrorHumanizer {
1860 err: self,
1861 humanizer: &DummyHumanizer,
1862 }
1863 .fmt(f)
1864 }
1865}
1866
1867impl<'a> TypeError<'a> {
1868 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1870 use TypeError::*;
1871 match self {
1872 Unbound { source, .. }
1873 | NoSuchColumn { source, .. }
1874 | MismatchColumn { source, .. }
1875 | MismatchColumns { source, .. }
1876 | BadConstantRowLen { source, .. }
1877 | BadConstantRow { source, .. }
1878 | BadProject { source, .. }
1879 | BadJoinEquivalence { source, .. }
1880 | BadTopKGroupKey { source, .. }
1881 | BadTopKOrdering { source, .. }
1882 | BadLetRecBindings { source }
1883 | Shadowing { source, .. }
1884 | DisallowedDummy { source, .. } => Some(source),
1885 Recursion { .. } => None,
1886 }
1887 }
1888
1889 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1890 where
1891 H: ExprHumanizer,
1892 {
1893 if let Some(source) = self.source() {
1894 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1895 }
1896
1897 use TypeError::*;
1898 match self {
1899 Unbound { source: _, id, typ } => {
1900 let typ = columns_pretty(&typ.column_types, humanizer);
1901 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1902 }
1903 NoSuchColumn {
1904 source: _,
1905 expr,
1906 col,
1907 } => writeln!(f, "{expr} references non-existent column {col}")?,
1908 MismatchColumn {
1909 source: _,
1910 got,
1911 expected,
1912 diffs,
1913 message,
1914 } => {
1915 let got = humanizer.humanize_column_type_repr(got, false);
1916 let expected = humanizer.humanize_column_type_repr(expected, false);
1917 writeln!(
1918 f,
1919 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1920 )?;
1921
1922 for diff in diffs {
1923 diff.humanize(2, humanizer, f)?;
1924 }
1925 }
1926 MismatchColumns {
1927 source: _,
1928 got,
1929 expected,
1930 diffs,
1931 message,
1932 } => {
1933 let got = columns_pretty(got, humanizer);
1934 let expected = columns_pretty(expected, humanizer);
1935
1936 writeln!(
1937 f,
1938 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1939 )?;
1940
1941 for diff in diffs {
1942 diff.humanize(humanizer, f)?;
1943 }
1944 }
1945 BadConstantRowLen {
1946 source: _,
1947 got,
1948 expected,
1949 } => {
1950 let expected = columns_pretty(expected, humanizer);
1951 writeln!(
1952 f,
1953 "bad constant row\n row has length {got}\nexpected row of type {expected}"
1954 )?
1955 }
1956 BadConstantRow {
1957 source: _,
1958 mismatches,
1959 expected,
1960 } => {
1961 let expected = columns_pretty(expected, humanizer);
1962
1963 let num_mismatches = mismatches.len();
1964 let plural = if num_mismatches == 1 { "" } else { "es" };
1965 writeln!(
1966 f,
1967 "bad constant row\n got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1968 )?;
1969
1970 if num_mismatches > 0 {
1971 writeln!(f, "")?;
1972 for (col, diff) in mismatches.iter() {
1973 writeln!(f, " column #{col}:")?;
1974 diff.humanize(8, humanizer, f)?;
1975 }
1976 }
1977 }
1978 BadProject {
1979 source: _,
1980 got,
1981 input_type,
1982 } => {
1983 let input_type = columns_pretty(input_type, humanizer);
1984
1985 writeln!(
1986 f,
1987 "projection of non-existant columns {got:?} from type {input_type}"
1988 )?
1989 }
1990 BadJoinEquivalence {
1991 source: _,
1992 got,
1993 message,
1994 } => {
1995 let got = columns_pretty(got, humanizer);
1996
1997 writeln!(f, "bad join equivalence {got}: {message}")?
1998 }
1999 BadTopKGroupKey {
2000 source: _,
2001 k,
2002 input_type,
2003 } => {
2004 let input_type = columns_pretty(input_type, humanizer);
2005
2006 writeln!(
2007 f,
2008 "TopK group key component references invalid column {k} in columns: {input_type}"
2009 )?
2010 }
2011 BadTopKOrdering {
2012 source: _,
2013 order,
2014 input_type,
2015 } => {
2016 let col = order.column;
2017 let num_cols = input_type.len();
2018 let are = if num_cols == 1 { "is" } else { "are" };
2019 let s = if num_cols == 1 { "" } else { "s" };
2020 let input_type = columns_pretty(input_type, humanizer);
2021
2022 let mode = HumanizedExplain::new(false);
2024 let order = mode.expr(order, None);
2025
2026 writeln!(
2027 f,
2028 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2029 )?
2030 }
2031 BadLetRecBindings { source: _ } => {
2032 writeln!(f, "LetRec ids and definitions don't line up")?
2033 }
2034 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2035 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2036 Recursion { error } => writeln!(f, "{error}")?,
2037 }
2038
2039 Ok(())
2040 }
2041}
2042
// Unit and property tests for `datum_difference_with_column_type`.
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    // Smoke test: an `Int16` datum is accepted for an `Int16` column and
    // rejected for an `Int32` column.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        // The datum is generated from the column type via `arb_datum_for_column`.
        // Cases whose datum contains a dummy are rejected, hence the explicit
        // `max_global_rejects` budget.
        #![proptest_config(ProptestConfig { cases: 5000, max_global_rejects: 2500, ..Default::default() })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data((src, datum) in any::<SqlColumnType>().prop_flat_map(|src| {
            let datum = arb_datum_for_column(src.clone());
            (Just(src), datum) }
        )) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            // `datum_difference_with_column_type` must agree with `is_instance_of`.
            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        // Type and datum are generated independently here, so many pairs will
        // not match; the case count is correspondingly high.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(src in any::<SqlColumnType>(), datum in arb_datum(false)) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            // NOTE(review): `arb_datum(false)` presumably disables dummy
            // generation; the assertion below treats a dummy as a generator bug.
            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    // Regression test (the name suggests GitHub issue 10039; confirm): a list
    // containing a null must not be accepted for a record type with one
    // non-nullable `UInt32` field.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}