1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27 ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28 SqlColumnType,
29};
30
/// A typing context that can be shared between owners: the [`Context`] map is
/// wrapped in a `Mutex` for synchronized access and in an `Arc` for shared ownership.
pub type SharedContext = Arc<Mutex<Context>>;
36
37pub fn empty_context() -> SharedContext {
39 Arc::new(Mutex::new(BTreeMap::new()))
40}
41
/// The errors that typechecking can report.
///
/// Apart from `Recursion`, every variant carries a `source` reference to the
/// relation expression that was being checked when the problem was detected.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referenced an identifier that has no entry in the typing context.
    Unbound {
        /// The expression containing the reference.
        source: &'a MirRelationExpr,
        /// The identifier that was not found.
        id: Id,
        /// The type annotation carried by the reference.
        typ: ReprRelationType,
    },
    /// A scalar expression referenced a column index outside of the input type.
    NoSuchColumn {
        /// The relation expression containing the scalar expression.
        source: &'a MirRelationExpr,
        /// The offending scalar expression.
        expr: &'a MirScalarExpr,
        /// The column index that was referenced.
        col: usize,
    },
    /// A single column did not have an acceptable type.
    MismatchColumn {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column type that was found.
        got: ReprColumnType,
        /// The column type that was expected.
        expected: ReprColumnType,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprColumnTypeDifference>,
        /// A description of the check that failed.
        message: String,
    },
    /// A whole relation type did not match the expected relation type.
    MismatchColumns {
        /// The expression in which the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column types that were found.
        got: Vec<ReprColumnType>,
        /// The column types that were expected.
        expected: Vec<ReprColumnType>,
        /// The individual differences between `got` and `expected`.
        diffs: Vec<ReprRelationTypeDifference>,
        /// A description of the check that failed.
        message: String,
    },
    /// A row of datums had a different number of entries than there are column types.
    BadConstantRowLen {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// The number of datums in the row.
        got: usize,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A row of datums contained values that do not fit their column types.
    BadConstantRow {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// For each offending datum, its position in the row and how it differs.
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The column types the row was checked against.
        expected: Vec<ReprColumnType>,
    },
    /// A `Project` referred to a column index outside of its input type.
    BadProject {
        /// The `Project` expression.
        source: &'a MirRelationExpr,
        /// The projected column indices.
        got: Vec<usize>,
        /// The column types of the input.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class failed a (strict) consistency check.
    BadJoinEquivalence {
        /// The `Join` expression.
        source: &'a MirRelationExpr,
        /// The types of the members of the equivalence class.
        got: Vec<ReprColumnType>,
        /// A description of the check that failed.
        message: String,
    },
    /// A `TopK` group key referred to a column index outside of its input type.
    BadTopKGroupKey {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The offending group key column.
        k: usize,
        /// The column types of the input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering referred to a column index outside of its input type.
    BadTopKOrdering {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The offending ordering.
        order: ColumnOrder,
        /// The column types of the input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` had a different number of identifiers than values.
    BadLetRecBindings {
        /// The `LetRec` expression.
        source: &'a MirRelationExpr,
    },
    /// A binding reused an identifier that is already bound.
    Shadowing {
        /// The binding expression.
        source: &'a MirRelationExpr,
        /// The identifier that was bound again.
        id: Id,
    },
    /// The recursion limit was exceeded while traversing the expression.
    Recursion {
        /// The underlying recursion limit error.
        error: RecursionLimitError,
    },
    /// A dummy value was found while dummies were disallowed.
    DisallowedDummy {
        /// The expression containing the dummy value.
        source: &'a MirRelationExpr,
    },
}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173 fn from(error: RecursionLimitError) -> Self {
174 TypeError::Recursion { error }
175 }
176}
177
/// A typing context: maps each bound identifier to the column types of its relation.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
/// One way in which a relation type fails to be a subtype of another
/// (see `relation_subtype_difference`).
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relation types have a different number of columns.
    Length {
        /// Number of columns in the candidate subtype.
        len_sub: usize,
        /// Number of columns in the candidate supertype.
        len_sup: usize,
    },
    /// The types of one column differ.
    Column {
        /// Index of the differing column.
        col: usize,
        /// How the column types differ.
        diff: ReprColumnTypeDifference,
    },
}
200
/// One way in which a column type fails to be a subtype of (or to be unioned
/// with) another column type.
///
/// Differences nest: a difference inside an element type is wrapped in
/// `ElementType`.
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar types are incompatible.
    NotSubtype {
        /// The candidate subtype.
        sub: ReprScalarType,
        /// The candidate supertype.
        sup: ReprScalarType,
    },
    /// The candidate subtype is nullable while the candidate supertype is not.
    Nullability {
        /// The candidate subtype.
        sub: ReprColumnType,
        /// The candidate supertype.
        sup: ReprColumnType,
    },
    /// The difference is in the element (or value) type of a container type.
    ElementType {
        /// Name of the container's type constructor.
        ctor: String,
        /// The difference between the element types.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// Fields missing from a record type.
    // NOTE(review): no code in the reviewed part of the file constructs this variant.
    RecordMissingFields {
        /// The names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Differences between the fields of two record types.
    // NOTE(review): no code in the reviewed part of the file constructs this variant.
    RecordFields {
        /// The differences found in the fields.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
239
240impl ReprRelationTypeDifference {
241 pub fn ignore_nullability(self) -> Option<Self> {
245 use ReprRelationTypeDifference::*;
246
247 match self {
248 Length { .. } => Some(self),
249 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250 }
251 }
252}
253
254impl ReprColumnTypeDifference {
255 pub fn ignore_nullability(self) -> Option<Self> {
259 use ReprColumnTypeDifference::*;
260
261 match self {
262 Nullability { .. } => None,
263 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264 ElementType { ctor, element_type } => {
265 element_type
266 .ignore_nullability()
267 .map(|element_type| ElementType {
268 ctor,
269 element_type: Box::new(element_type),
270 })
271 }
272 RecordFields { fields } => {
273 let fields = fields
274 .into_iter()
275 .flat_map(|diff| diff.ignore_nullability())
276 .collect::<Vec<_>>();
277
278 if fields.is_empty() {
279 None
280 } else {
281 Some(RecordFields { fields })
282 }
283 }
284 }
285 }
286}
287
288pub fn relation_subtype_difference(
292 sub: &[ReprColumnType],
293 sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295 let mut diffs = Vec::new();
296
297 if sub.len() != sup.len() {
298 diffs.push(ReprRelationTypeDifference::Length {
299 len_sub: sub.len(),
300 len_sup: sup.len(),
301 });
302
303 return diffs;
305 }
306
307 diffs.extend(
308 sub.iter()
309 .zip_eq(sup.iter())
310 .enumerate()
311 .flat_map(|(col, (sub_ty, sup_ty))| {
312 column_subtype_difference(sub_ty, sup_ty)
313 .into_iter()
314 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315 }),
316 );
317
318 diffs
319}
320
321pub fn column_subtype_difference(
325 sub: &ReprColumnType,
326 sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330 if sub.nullable && !sup.nullable {
331 diffs.push(ReprColumnTypeDifference::Nullability {
332 sub: sub.clone(),
333 sup: sup.clone(),
334 });
335 }
336
337 diffs
338}
339
340pub fn scalar_subtype_difference(
344 sub: &ReprScalarType,
345 sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347 use ReprScalarType::*;
348
349 let mut diffs = Vec::new();
350
351 match (sub, sup) {
352 (
353 List {
354 element_type: sub_elt,
355 ..
356 },
357 List {
358 element_type: sup_elt,
359 ..
360 },
361 )
362 | (
363 Map {
364 value_type: sub_elt,
365 ..
366 },
367 Map {
368 value_type: sup_elt,
369 ..
370 },
371 )
372 | (
373 Range {
374 element_type: sub_elt,
375 ..
376 },
377 Range {
378 element_type: sup_elt,
379 ..
380 },
381 )
382 | (Array(sub_elt), Array(sup_elt)) => {
383 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384 diffs.extend(
385 scalar_subtype_difference(sub_elt, sup_elt)
386 .into_iter()
387 .map(|diff| ReprColumnTypeDifference::ElementType {
388 ctor: ctor.clone(),
389 element_type: Box::new(diff),
390 }),
391 );
392 }
393 (
394 Record {
395 fields: sub_fields, ..
396 },
397 Record {
398 fields: sup_fields, ..
399 },
400 ) => {
401 if sub_fields.len() != sup_fields.len() {
402 diffs.push(ReprColumnTypeDifference::NotSubtype {
403 sub: sub.clone(),
404 sup: sup.clone(),
405 });
406 return diffs;
407 }
408
409 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411 }
412 }
413 (_, _) => {
414 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415 diffs.push(ReprColumnTypeDifference::NotSubtype {
416 sub: sub.clone(),
417 sup: sup.clone(),
418 })
419 }
420 }
421 };
422
423 diffs
424}
425
426pub fn scalar_union(
430 typ: &mut ReprScalarType,
431 other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433 use ReprScalarType::*;
434
435 let mut diffs = Vec::new();
436
437 let ctor = ReprScalarBaseType::from(&*typ);
439 match (typ, other) {
440 (
441 List {
442 element_type: typ_elt,
443 },
444 List {
445 element_type: other_elt,
446 },
447 )
448 | (
449 Map {
450 value_type: typ_elt,
451 },
452 Map {
453 value_type: other_elt,
454 },
455 )
456 | (
457 Range {
458 element_type: typ_elt,
459 },
460 Range {
461 element_type: other_elt,
462 },
463 )
464 | (Array(typ_elt), Array(other_elt)) => {
465 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466 diffs.extend(
467 res.into_iter()
468 .map(|diff| ReprColumnTypeDifference::ElementType {
469 ctor: format!("{ctor:?}"),
470 element_type: Box::new(diff),
471 }),
472 );
473 }
474 (
475 Record { fields: typ_fields },
476 Record {
477 fields: other_fields,
478 },
479 ) => {
480 if typ_fields.len() != other_fields.len() {
481 diffs.push(ReprColumnTypeDifference::NotSubtype {
482 sub: ReprScalarType::Record {
483 fields: typ_fields.clone(),
484 },
485 sup: other.clone(),
486 });
487 return diffs;
488 }
489
490 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491 diffs.extend(column_union(typ_ty, other_ty));
492 }
493 }
494 (typ, _) => {
495 if ctor != ReprScalarBaseType::from(other) {
496 diffs.push(ReprColumnTypeDifference::NotSubtype {
497 sub: typ.clone(),
498 sup: other.clone(),
499 })
500 }
501 }
502 };
503
504 diffs
505}
506
507pub fn column_union(
511 typ: &mut ReprColumnType,
512 other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516 if diffs.is_empty() {
517 typ.nullable |= other.nullable;
518 }
519
520 diffs
521}
522
523pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528 if sub.len() != sup.len() {
529 return false;
530 }
531
532 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534 })
535}
536
/// One way in which a datum fails to fit a type
/// (see `datum_difference_with_column_type`).
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// A null datum was found where a non-null value was required.
    Null {
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// The datum is of a different kind than the expected scalar type.
    Mismatch {
        /// The `Debug` rendering of the datum that was found.
        got_debug: String,
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// A container datum has the wrong number of dimensions or entries.
    MismatchDimensions {
        /// Name of the container type (e.g. `"int2vector"` or `"record"`).
        ctor: String,
        /// The number found.
        got: usize,
        /// The number expected.
        expected: usize,
    },
    /// The difference is in an element of a container datum.
    ElementType {
        /// Name of the container type (e.g. `"array"`, `"list"`, `"map"`).
        ctor: String,
        /// The difference found in the element.
        element_type: Box<DatumTypeDifference>,
    },
}
570
/// Checks that `datum` is a valid value of `column_type`.
///
/// Returns `Ok(())` when the datum fits and otherwise the first difference
/// found. A null datum is accepted when the column is nullable; everything
/// else is decided by comparing the datum with the column's scalar type.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    /// Checks `datum` against `scalar_type`, ignoring column nullability.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        /// Builds the error for a datum of the wrong kind.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        if let ReprScalarType::Jsonb = scalar_type {
            // Jsonb accepts several kinds of datums; lists and maps are
            // checked recursively against Jsonb itself.
            match datum {
                // Dummy datums are accepted here; a plain null is rejected for Jsonb.
                Datum::Dummy => Ok(()), Datum::Null => Err(DatumTypeDifference::Null {
                    expected: ReprScalarType::Jsonb,
                }),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                Datum::List(list) => {
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            /// Wraps a difference found inside an element of the container `ctor`.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            // Every kind of datum is listed explicitly: first the scalar
            // type(s) it fits, then a fallback arm reporting a mismatch.
            match (datum, scalar_type) {
                // Dummy datums fit every type; a null reaching this point is an error.
                (Datum::Dummy, _) => Ok(()), (Datum::Null, _) => Err(DatumTypeDifference::Null {
                    expected: scalar_type.clone(),
                }),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    for e in array.elements().iter() {
                        // Null elements are accepted in arrays.
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    if !array.has_int2vector_dims() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    // Unlike for arrays, null elements are not skipped here,
                    // so a null element is reported as a difference.
                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    for e in list.iter() {
                        // Null elements are accepted in lists.
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                // A list datum is also checked as a record: one entry per field.
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        // A null entry is only acceptable for a nullable field.
                        if let Datum::Null = e {
                            if t.nullable {
                                continue;
                            } else {
                                return Err(DatumTypeDifference::Null {
                                    expected: t.scalar_type.clone(),
                                });
                            }
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    for (_, v) in map.iter() {
                        // Null values are accepted in maps.
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // JsonNull only fits Jsonb, which is handled in the branch above.
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    match inner {
                        // A range without an inner value has no bounds to check.
                        None => Ok(()),
                        Some(inner) => {
                            // Check whichever bounds are present against the element type.
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // A null datum fits any nullable column, whatever its scalar type.
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
787
788fn row_difference_with_column_types<'a>(
789 source: &'a MirRelationExpr,
790 datums: &Vec<Datum<'_>>,
791 column_types: &[ReprColumnType],
792) -> Result<(), TypeError<'a>> {
793 if datums.len() != column_types.len() {
795 return Err(TypeError::BadConstantRowLen {
796 source,
797 got: datums.len(),
798 expected: column_types.to_vec(),
799 });
800 }
801
802 let mut mismatches = Vec::new();
804 for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
805 if let Err(e) = datum_difference_with_column_type(d, ty) {
806 mismatches.push((i, e));
807 }
808 }
809 if !mismatches.is_empty() {
810 return Err(TypeError::BadConstantRow {
811 source,
812 mismatches,
813 expected: column_types.to_vec(),
814 });
815 }
816
817 Ok(())
818}
/// A typechecker for MIR relation expressions.
///
/// Created with `Typecheck::new` and configured with the builder methods
/// `disallow_new_globals`, `strict_join_equivalences` and `disallow_dummy`.
#[derive(Debug)]
pub struct Typecheck {
    /// The typing context shared with other holders of the same `SharedContext`.
    ctx: SharedContext,
    /// Set by `disallow_new_globals()`.
    // NOTE(review): not read by any code in the reviewed part of the file.
    disallow_new_globals: bool,
    /// When set, join equivalence classes are additionally checked for
    /// consistent nullability.
    strict_join_equivalences: bool,
    /// When set, expressions containing dummy values are rejected.
    disallow_dummy: bool,
    /// Bounds the recursion depth of the typechecking traversals.
    recursion_guard: RecursionGuard,
}
833
/// Gives `checked_recur` access to the typechecker's recursion guard.
impl CheckedRecursion for Typecheck {
    /// Returns the guard that tracks the recursion depth of this typechecker.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
839
840impl Typecheck {
841 pub fn new(ctx: SharedContext) -> Self {
843 Self {
844 ctx,
845 disallow_new_globals: false,
846 strict_join_equivalences: false,
847 disallow_dummy: false,
848 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
849 }
850 }
851
852 pub fn disallow_new_globals(mut self) -> Self {
856 self.disallow_new_globals = true;
857 self
858 }
859
860 pub fn strict_join_equivalences(mut self) -> Self {
864 self.strict_join_equivalences = true;
865
866 self
867 }
868
869 pub fn disallow_dummy(mut self) -> Self {
871 self.disallow_dummy = true;
872 self
873 }
874
875 pub fn typecheck<'a>(
886 &self,
887 expr: &'a MirRelationExpr,
888 ctx: &Context,
889 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
890 use MirRelationExpr::*;
891
892 self.checked_recur(|tc| match expr {
893 Constant { typ, rows } => {
894 if let Ok(rows) = rows {
895 for (row, _id) in rows {
896 let datums = row.unpack();
897
898 let col_types = typ
899 .column_types
900 .iter()
901 .map(ReprColumnType::from)
902 .collect_vec();
903 row_difference_with_column_types(
904 expr, &datums, &col_types,
905 )?;
906
907 if self.disallow_dummy
908 && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
909 {
910 return Err(TypeError::DisallowedDummy {
911 source: expr,
912 });
913 }
914 }
915 }
916
917 Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
918 }
919 Get { typ, id, .. } => {
920 if let Id::Global(_global_id) = id {
921 if !ctx.contains_key(id) {
922 return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
924 }
925 }
926
927 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
928 source: expr,
929 id: id.clone(),
930 typ: ReprRelationType::from(typ),
931 })?;
932
933 let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
934
935 let diffs = relation_subtype_difference(&column_types, ctx_typ)
937 .into_iter()
938 .flat_map(|diff| diff.ignore_nullability())
939 .collect::<Vec<_>>();
940
941 if !diffs.is_empty() {
942 return Err(TypeError::MismatchColumns {
943 source: expr,
944 got: column_types,
945 expected: ctx_typ.clone(),
946 diffs,
947 message: "annotation did not match context type".to_string(),
948 });
949 }
950
951 Ok(column_types)
952 }
953 Project { input, outputs } => {
954 let t_in = tc.typecheck(input, ctx)?;
955
956 for x in outputs {
957 if *x >= t_in.len() {
958 return Err(TypeError::BadProject {
959 source: expr,
960 got: outputs.clone(),
961 input_type: t_in,
962 });
963 }
964 }
965
966 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
967 }
968 Map { input, scalars } => {
969 let mut t_in = tc.typecheck(input, ctx)?;
970
971 for scalar_expr in scalars.iter() {
972 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
973
974 if self.disallow_dummy && scalar_expr.contains_dummy() {
975 return Err(TypeError::DisallowedDummy {
976 source: expr,
977 });
978 }
979 }
980
981 Ok(t_in)
982 }
983 FlatMap { input, func, exprs } => {
984 let mut t_in = tc.typecheck(input, ctx)?;
985
986 let mut t_exprs = Vec::with_capacity(exprs.len());
987 for scalar_expr in exprs {
988 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
989
990 if self.disallow_dummy && scalar_expr.contains_dummy() {
991 return Err(TypeError::DisallowedDummy {
992 source: expr,
993 });
994 }
995 }
996 let t_out = func
999 .output_type()
1000 .column_types
1001 .iter()
1002 .map(ReprColumnType::from)
1003 .collect_vec();
1004
1005 t_in.extend(t_out);
1007 Ok(t_in)
1008 }
1009 Filter { input, predicates } => {
1010 let mut t_in = tc.typecheck(input, ctx)?;
1011
1012 for column in non_nullable_columns(predicates) {
1015 t_in[column].nullable = false;
1016 }
1017
1018 for scalar_expr in predicates {
1019 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1020
1021 if t.scalar_type != ReprScalarType::Bool {
1025 let sub = t.scalar_type.clone();
1026
1027 return Err(TypeError::MismatchColumn {
1028 source: expr,
1029 got: t,
1030 expected: ReprColumnType {
1031 scalar_type: ReprScalarType::Bool,
1032 nullable: true,
1033 },
1034 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1035 sub,
1036 sup: ReprScalarType::Bool,
1037 }],
1038 message: "expected boolean condition".to_string(),
1039 });
1040 }
1041
1042 if self.disallow_dummy && scalar_expr.contains_dummy() {
1043 return Err(TypeError::DisallowedDummy {
1044 source: expr,
1045 });
1046 }
1047 }
1048
1049 Ok(t_in)
1050 }
1051 Join {
1052 inputs,
1053 equivalences,
1054 implementation,
1055 } => {
1056 let mut t_in_global = Vec::new();
1057 let mut t_in_local = vec![Vec::new(); inputs.len()];
1058
1059 for (i, input) in inputs.iter().enumerate() {
1060 let input_t = tc.typecheck(input, ctx)?;
1061 t_in_global.extend(input_t.clone());
1062 t_in_local[i] = input_t;
1063 }
1064
1065 for eq_class in equivalences {
1066 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1067
1068 let mut all_nullable = true;
1069
1070 for scalar_expr in eq_class {
1071 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1073
1074 if !t_expr.nullable {
1075 all_nullable = false;
1076 }
1077
1078 if let Some(t_first) = t_exprs.get(0) {
1079 let diffs = scalar_subtype_difference(
1080 &t_expr.scalar_type,
1081 &t_first.scalar_type,
1082 );
1083 if !diffs.is_empty() {
1084 return Err(TypeError::MismatchColumn {
1085 source: expr,
1086 got: t_expr,
1087 expected: t_first.clone(),
1088 diffs,
1089 message: "equivalence class members \
1090 have different scalar types"
1091 .to_string(),
1092 });
1093 }
1094
1095 if self.strict_join_equivalences {
1099 if t_expr.nullable != t_first.nullable {
1100 let sub = t_expr.clone();
1101 let sup = t_first.clone();
1102
1103 let err = TypeError::MismatchColumn {
1104 source: expr,
1105 got: t_expr.clone(),
1106 expected: t_first.clone(),
1107 diffs: vec![
1108 ReprColumnTypeDifference::Nullability { sub, sup },
1109 ],
1110 message: "equivalence class members have \
1111 different nullability (and join \
1112 equivalence checking is strict)"
1113 .to_string(),
1114 };
1115
1116 ::tracing::debug!("{err}");
1118 }
1119 }
1120 }
1121
1122 if self.disallow_dummy && scalar_expr.contains_dummy() {
1123 return Err(TypeError::DisallowedDummy {
1124 source: expr,
1125 });
1126 }
1127
1128 t_exprs.push(t_expr);
1129 }
1130
1131 if self.strict_join_equivalences && all_nullable {
1132 let err = TypeError::BadJoinEquivalence {
1133 source: expr,
1134 got: t_exprs,
1135 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1136 };
1137
1138 ::tracing::debug!("{err}");
1140 }
1141 }
1142
1143 match implementation {
1145 JoinImplementation::Differential((start_idx, first_key, _), others) => {
1146 if let Some(key) = first_key {
1147 for k in key {
1148 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1149 }
1150 }
1151
1152 for (idx, key, _) in others {
1153 for k in key {
1154 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1155 }
1156 }
1157 }
1158 JoinImplementation::DeltaQuery(plans) => {
1159 for plan in plans {
1160 for (idx, key, _) in plan {
1161 for k in key {
1162 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1163 }
1164 }
1165 }
1166 }
1167 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1168 let typ: Vec<ReprColumnType> = key
1169 .iter()
1170 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1171 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1172
1173 for row in consts {
1174 let datums = row.unpack();
1175
1176 row_difference_with_column_types(expr, &datums, &typ)?;
1177 }
1178 }
1179 JoinImplementation::Unimplemented => (),
1180 }
1181
1182 Ok(t_in_global)
1183 }
1184 Reduce {
1185 input,
1186 group_key,
1187 aggregates,
1188 monotonic: _,
1189 expected_group_size: _,
1190 } => {
1191 let t_in = tc.typecheck(input, ctx)?;
1192
1193 let mut t_out = group_key
1194 .iter()
1195 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1196 .collect::<Result<Vec<_>, _>>()?;
1197
1198 if self.disallow_dummy
1199 && group_key
1200 .iter()
1201 .any(|scalar_expr| scalar_expr.contains_dummy())
1202 {
1203 return Err(TypeError::DisallowedDummy {
1204 source: expr,
1205 });
1206 }
1207
1208 for agg in aggregates {
1209 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1210 }
1211
1212 Ok(t_out)
1213 }
1214 TopK {
1215 input,
1216 group_key,
1217 order_key,
1218 limit: _,
1219 offset: _,
1220 monotonic: _,
1221 expected_group_size: _,
1222 } => {
1223 let t_in = tc.typecheck(input, ctx)?;
1224
1225 for &k in group_key {
1226 if k >= t_in.len() {
1227 return Err(TypeError::BadTopKGroupKey {
1228 source: expr,
1229 k,
1230 input_type: t_in,
1231 });
1232 }
1233 }
1234
1235 for order in order_key {
1236 if order.column >= t_in.len() {
1237 return Err(TypeError::BadTopKOrdering {
1238 source: expr,
1239 order: order.clone(),
1240 input_type: t_in,
1241 });
1242 }
1243 }
1244
1245 Ok(t_in)
1246 }
1247 Negate { input } => tc.typecheck(input, ctx),
1248 Threshold { input } => tc.typecheck(input, ctx),
1249 Union { base, inputs } => {
1250 let mut t_base = tc.typecheck(base, ctx)?;
1251
1252 for input in inputs {
1253 let t_input = tc.typecheck(input, ctx)?;
1254
1255 let len_sub = t_base.len();
1256 let len_sup = t_input.len();
1257 if len_sub != len_sup {
1258 return Err(TypeError::MismatchColumns {
1259 source: expr,
1260 got: t_base.clone(),
1261 expected: t_input,
1262 diffs: vec![ReprRelationTypeDifference::Length {
1263 len_sub,
1264 len_sup,
1265 }],
1266 message: "Union branches have different numbers of columns".to_string(),
1267 });
1268 }
1269
1270 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1271 let diffs = column_union(base_col, &input_col);
1272 if !diffs.is_empty() {
1273 return Err(TypeError::MismatchColumn {
1274 source: expr,
1275 got: input_col,
1276 expected: base_col.clone(),
1277 diffs,
1278 message:
1279 "couldn't compute union of column types in Union"
1280 .to_string(),
1281 });
1282 }
1283
1284 }
1285 }
1286
1287 Ok(t_base)
1288 }
1289 Let { id, value, body } => {
1290 let t_value = tc.typecheck(value, ctx)?;
1291
1292 let binding = Id::Local(*id);
1293 if ctx.contains_key(&binding) {
1294 return Err(TypeError::Shadowing {
1295 source: expr,
1296 id: binding,
1297 });
1298 }
1299
1300 let mut body_ctx = ctx.clone();
1301 body_ctx.insert(Id::Local(*id), t_value);
1302
1303 tc.typecheck(body, &body_ctx)
1304 }
1305 LetRec { ids, values, body, limits: _ } => {
1306 if ids.len() != values.len() {
1307 return Err(TypeError::BadLetRecBindings { source: expr });
1308 }
1309
1310 let mut ctx = ctx.clone();
1313 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1315 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1316 }
1317
1318 for (id, value) in ids.iter().zip_eq(values.iter()) {
1319 let typ = tc.typecheck(value, &ctx)?;
1320
1321 let id = Id::Local(id.clone());
1322 if let Some(ctx_typ) = ctx.get_mut(&id) {
1323 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1324 let diffs = column_union(base_col, &input_col);
1326 if !diffs.is_empty() {
1327 return Err(TypeError::MismatchColumn {
1328 source: expr,
1329 got: input_col,
1330 expected: base_col.clone(),
1331 diffs,
1332 message:
1333 "couldn't compute union of column types in LetRec"
1334 .to_string(),
1335 })
1336 }
1337 }
1338 } else {
1339 ctx.insert(id, typ);
1341 }
1342 }
1343
1344 tc.typecheck(body, &ctx)
1345 }
1346 ArrangeBy { input, keys } => {
1347 let t_in = tc.typecheck(input, ctx)?;
1348
1349 for key in keys {
1350 for k in key {
1351 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1352 }
1353 }
1354
1355 Ok(t_in)
1356 }
1357 })
1358 }
1359
1360 fn collect_recursive_variable_types<'a>(
1364 &self,
1365 expr: &'a MirRelationExpr,
1366 ids: &[LocalId],
1367 ctx: &mut Context,
1368 ) -> Result<(), TypeError<'a>> {
1369 use MirRelationExpr::*;
1370
1371 self.checked_recur(|tc| {
1372 match expr {
1373 Get {
1374 id: Id::Local(id),
1375 typ,
1376 ..
1377 } => {
1378 if !ids.contains(id) {
1379 return Ok(());
1380 }
1381
1382 let id = Id::Local(id.clone());
1383 if let Some(ctx_typ) = ctx.get_mut(&id) {
1384 let typ = typ
1385 .column_types
1386 .iter()
1387 .map(ReprColumnType::from)
1388 .collect_vec();
1389
1390 if ctx_typ.len() != typ.len() {
1391 let diffs = relation_subtype_difference(&typ, ctx_typ);
1392
1393 return Err(TypeError::MismatchColumns {
1394 source: expr,
1395 got: typ,
1396 expected: ctx_typ.clone(),
1397 diffs,
1398 message: "environment and type annotation did not match"
1399 .to_string(),
1400 });
1401 }
1402
1403 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1404 let diffs = column_union(base_col, &input_col);
1405 if !diffs.is_empty() {
1406 return Err(TypeError::MismatchColumn {
1407 source: expr,
1408 got: input_col,
1409 expected: base_col.clone(),
1410 diffs,
1411 message:
1412 "couldn't compute union of column types in Get and context"
1413 .to_string(),
1414 });
1415 }
1416 }
1417 } else {
1418 ctx.insert(
1419 id,
1420 typ.column_types
1421 .iter()
1422 .map(ReprColumnType::from)
1423 .collect_vec(),
1424 );
1425 }
1426 }
1427 Get {
1428 id: Id::Global(..), ..
1429 }
1430 | Constant { .. } => (),
1431 Let { id, value, body } => {
1432 tc.collect_recursive_variable_types(value, ids, ctx)?;
1433
1434 if ids.contains(id) {
1436 return Err(TypeError::Shadowing {
1437 source: expr,
1438 id: Id::Local(*id),
1439 });
1440 }
1441
1442 tc.collect_recursive_variable_types(body, ids, ctx)?;
1443 }
1444 LetRec {
1445 ids: inner_ids,
1446 values,
1447 body,
1448 limits: _,
1449 } => {
1450 for inner_id in inner_ids {
1451 if ids.contains(inner_id) {
1452 return Err(TypeError::Shadowing {
1453 source: expr,
1454 id: Id::Local(*inner_id),
1455 });
1456 }
1457 }
1458
1459 for value in values {
1460 tc.collect_recursive_variable_types(value, ids, ctx)?;
1461 }
1462
1463 tc.collect_recursive_variable_types(body, ids, ctx)?;
1464 }
1465 Project { input, .. }
1466 | Map { input, .. }
1467 | FlatMap { input, .. }
1468 | Filter { input, .. }
1469 | Reduce { input, .. }
1470 | TopK { input, .. }
1471 | Negate { input }
1472 | Threshold { input }
1473 | ArrangeBy { input, .. } => {
1474 tc.collect_recursive_variable_types(input, ids, ctx)?;
1475 }
1476 Join { inputs, .. } => {
1477 for input in inputs {
1478 tc.collect_recursive_variable_types(input, ids, ctx)?;
1479 }
1480 }
1481 Union { base, inputs } => {
1482 tc.collect_recursive_variable_types(base, ids, ctx)?;
1483
1484 for input in inputs {
1485 tc.collect_recursive_variable_types(input, ids, ctx)?;
1486 }
1487 }
1488 }
1489
1490 Ok(())
1491 })
1492 }
1493
1494 fn typecheck_scalar<'a>(
1495 &self,
1496 expr: &'a MirScalarExpr,
1497 source: &'a MirRelationExpr,
1498 column_types: &[ReprColumnType],
1499 ) -> Result<ReprColumnType, TypeError<'a>> {
1500 use MirScalarExpr::*;
1501
1502 self.checked_recur(|tc| match expr {
1503 Column(i, _) => match column_types.get(*i) {
1504 Some(ty) => Ok(ty.clone()),
1505 None => Err(TypeError::NoSuchColumn {
1506 source,
1507 expr,
1508 col: *i,
1509 }),
1510 },
1511 Literal(row, typ) => {
1512 let typ = ReprColumnType::from(typ);
1513 if let Ok(row) = row {
1514 let datums = row.unpack();
1515
1516 row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1517 }
1518
1519 Ok(typ)
1520 }
1521 CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1522 CallUnary { expr, func } => {
1523 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1524 let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1525 Ok(ReprColumnType::from(&typ_out))
1526 }
1527 CallBinary { expr1, expr2, func } => {
1528 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1529 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1530 let typ_out = func.output_type(&[
1531 SqlColumnType::from_repr(&typ_in1),
1532 SqlColumnType::from_repr(&typ_in2),
1533 ]);
1534 Ok(ReprColumnType::from(&typ_out))
1535 }
1536 CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1537 &func.output_type(
1538 exprs
1539 .iter()
1540 .map(|e| {
1541 tc.typecheck_scalar(e, source, column_types)
1542 .map(|typ| SqlColumnType::from_repr(&typ))
1543 })
1544 .collect::<Result<Vec<_>, TypeError>>()?,
1545 ),
1546 )),
1547 If { cond, then, els } => {
1548 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1549
1550 if cond_type.scalar_type != ReprScalarType::Bool {
1554 let sub = cond_type.scalar_type.clone();
1555
1556 return Err(TypeError::MismatchColumn {
1557 source,
1558 got: cond_type,
1559 expected: ReprColumnType {
1560 scalar_type: ReprScalarType::Bool,
1561 nullable: true,
1562 },
1563 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1564 sub,
1565 sup: ReprScalarType::Bool,
1566 }],
1567 message: "expected boolean condition".to_string(),
1568 });
1569 }
1570
1571 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1572 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1573
1574 let diffs = column_union(&mut then_type, &else_type);
1575 if !diffs.is_empty() {
1576 return Err(TypeError::MismatchColumn {
1577 source,
1578 got: then_type,
1579 expected: else_type,
1580 diffs,
1581 message: "couldn't compute union of column types for If".to_string(),
1582 });
1583 }
1584
1585 Ok(then_type)
1586 }
1587 })
1588 }
1589
1590 pub fn typecheck_aggregate<'a>(
1592 &self,
1593 expr: &'a AggregateExpr,
1594 source: &'a MirRelationExpr,
1595 column_types: &[ReprColumnType],
1596 ) -> Result<ReprColumnType, TypeError<'a>> {
1597 self.checked_recur(|tc| {
1598 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1599
1600 Ok(ReprColumnType::from(
1603 &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1604 ))
1605 })
1606 }
1607}
1608
/// Reports a type error at one of two severities.
///
/// The first argument is a `bool`. When it is `true` the remaining format arguments are passed
/// to `soft_panic_or_log!`; when it is `false` they are emitted through `tracing::debug!`.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        let severe: bool = $severity;
        if severe {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1621
1622impl crate::Transform for Typecheck {
1623 fn name(&self) -> &'static str {
1624 "Typecheck"
1625 }
1626
1627 fn actually_perform_transform(
1628 &self,
1629 relation: &mut MirRelationExpr,
1630 transform_ctx: &mut crate::TransformCtx,
1631 ) -> Result<(), crate::TransformError> {
1632 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1633
1634 let expected = transform_ctx
1635 .global_id
1636 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1637
1638 if let Some(id) = transform_ctx.global_id {
1639 if self.disallow_new_globals
1640 && expected.is_none()
1641 && transform_ctx.global_id.is_some()
1642 && !id.is_transient()
1643 {
1644 type_error!(
1645 false, "type warning: new non-transient global id {id}\n{}",
1647 relation.pretty()
1648 );
1649 }
1650 }
1651
1652 let got = self.typecheck(relation, &typecheck_ctx);
1653
1654 let humanizer = mz_repr::explain::DummyHumanizer;
1655
1656 match (got, expected) {
1657 (Ok(got), Some(expected)) => {
1658 let id = transform_ctx.global_id.unwrap();
1659
1660 let diffs = relation_subtype_difference(expected, &got);
1662 if !diffs.is_empty() {
1663 let severity = diffs
1665 .iter()
1666 .any(|diff| diff.clone().ignore_nullability().is_some());
1667
1668 let err = TypeError::MismatchColumns {
1669 source: relation,
1670 got,
1671 expected: expected.clone(),
1672 diffs,
1673 message: format!(
1674 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1675 ),
1676 };
1677
1678 type_error!(severity, "type error in known global id {id}:\n{err}");
1679 }
1680 }
1681 (Ok(got), None) => {
1682 if let Some(id) = transform_ctx.global_id {
1683 typecheck_ctx.insert(Id::Global(id), got);
1684 }
1685 }
1686 (Err(err), _) => {
1687 let (expected, binding) = match expected {
1688 Some(expected) => {
1689 let id = transform_ctx.global_id.unwrap();
1690 (
1691 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1692 format!("known global id {id}"),
1693 )
1694 }
1695 None => ("".to_string(), "transient query".to_string()),
1696 };
1697
1698 type_error!(
1699 true, "type error in {binding}:\n{err}\n{expected}{}",
1701 relation.pretty()
1702 );
1703 }
1704 }
1705
1706 Ok(())
1707 }
1708}
1709
1710pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1712where
1713 H: ExprHumanizer,
1714{
1715 let mut s = String::with_capacity(2 + 3 * cols.len());
1716
1717 s.push('(');
1718
1719 let mut it = cols.iter().peekable();
1720 while let Some(col) = it.next() {
1721 s.push_str(&humanizer.humanize_column_type_repr(col, false));
1722
1723 if it.peek().is_some() {
1724 s.push_str(", ");
1725 }
1726 }
1727
1728 s.push(')');
1729
1730 s
1731}
1732
1733impl ReprRelationTypeDifference {
1734 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1738 where
1739 H: ExprHumanizer,
1740 {
1741 use ReprRelationTypeDifference::*;
1742 match self {
1743 Length { len_sub, len_sup } => {
1744 writeln!(
1745 f,
1746 " number of columns do not match ({len_sub} != {len_sup})"
1747 )
1748 }
1749 Column { col, diff } => {
1750 writeln!(f, " column {col} differs:")?;
1751 diff.humanize(4, h, f)
1752 }
1753 }
1754 }
1755}
1756
1757impl ReprColumnTypeDifference {
1758 pub fn humanize<H>(
1760 &self,
1761 indent: usize,
1762 h: &H,
1763 f: &mut std::fmt::Formatter<'_>,
1764 ) -> std::fmt::Result
1765 where
1766 H: ExprHumanizer,
1767 {
1768 use ReprColumnTypeDifference::*;
1769
1770 write!(f, "{:indent$}", "")?;
1772
1773 match self {
1774 NotSubtype { sub, sup } => {
1775 let sub = h.humanize_scalar_type_repr(sub, false);
1776 let sup = h.humanize_scalar_type_repr(sup, false);
1777
1778 writeln!(f, "{sub} is a not a subtype of {sup}")
1779 }
1780 Nullability { sub, sup } => {
1781 let sub = h.humanize_column_type_repr(sub, false);
1782 let sup = h.humanize_column_type_repr(sup, false);
1783
1784 writeln!(f, "{sub} is nullable but {sup} is not")
1785 }
1786 ElementType { ctor, element_type } => {
1787 writeln!(f, "{ctor} element types differ:")?;
1788
1789 element_type.humanize(indent + 2, h, f)
1790 }
1791 RecordMissingFields { missing } => {
1792 write!(f, "missing column fields:")?;
1793 for col in missing {
1794 write!(f, " {col}")?;
1795 }
1796 f.write_char('\n')
1797 }
1798 RecordFields { fields } => {
1799 writeln!(f, "{} record fields differ:", fields.len())?;
1800
1801 for (i, diff) in fields.iter().enumerate() {
1802 writeln!(f, "{:indent$} field {i}:", "")?;
1803 diff.humanize(indent + 4, h, f)?;
1804 }
1805 Ok(())
1806 }
1807 }
1808 }
1809}
1810
1811impl DatumTypeDifference {
1812 pub fn humanize<H>(
1814 &self,
1815 indent: usize,
1816 h: &H,
1817 f: &mut std::fmt::Formatter<'_>,
1818 ) -> std::fmt::Result
1819 where
1820 H: ExprHumanizer,
1821 {
1822 write!(f, "{:indent$}", "")?;
1824
1825 match self {
1826 DatumTypeDifference::Null { expected } => {
1827 let expected = h.humanize_scalar_type_repr(expected, false);
1828 writeln!(
1829 f,
1830 "unexpected null, expected representation type {expected}"
1831 )?
1832 }
1833 DatumTypeDifference::Mismatch {
1834 got_debug,
1835 expected,
1836 } => {
1837 let expected = h.humanize_scalar_type_repr(expected, false);
1838 writeln!(
1840 f,
1841 "got datum {got_debug}, expected representation type {expected}"
1842 )?;
1843 }
1844 DatumTypeDifference::MismatchDimensions {
1845 ctor,
1846 got,
1847 expected,
1848 } => {
1849 writeln!(
1850 f,
1851 "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1852 )?;
1853 }
1854 DatumTypeDifference::ElementType { ctor, element_type } => {
1855 writeln!(f, "{ctor} element types differ:")?;
1856 element_type.humanize(indent + 4, h, f)?;
1857 }
1858 }
1859
1860 Ok(())
1861 }
1862}
1863
1864#[allow(missing_debug_implementations)]
1866pub struct TypeErrorHumanizer<'a, 'b, H>
1867where
1868 H: ExprHumanizer,
1869{
1870 err: &'a TypeError<'a>,
1871 humanizer: &'b H,
1872}
1873
1874impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1875where
1876 H: ExprHumanizer,
1877{
1878 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1880 Self { err, humanizer }
1881 }
1882}
1883
1884impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1885where
1886 H: ExprHumanizer,
1887{
1888 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1889 self.err.humanize(self.humanizer, f)
1890 }
1891}
1892
1893impl<'a> std::fmt::Display for TypeError<'a> {
1894 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1895 TypeErrorHumanizer {
1896 err: self,
1897 humanizer: &DummyHumanizer,
1898 }
1899 .fmt(f)
1900 }
1901}
1902
1903impl<'a> TypeError<'a> {
1904 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1906 use TypeError::*;
1907 match self {
1908 Unbound { source, .. }
1909 | NoSuchColumn { source, .. }
1910 | MismatchColumn { source, .. }
1911 | MismatchColumns { source, .. }
1912 | BadConstantRowLen { source, .. }
1913 | BadConstantRow { source, .. }
1914 | BadProject { source, .. }
1915 | BadJoinEquivalence { source, .. }
1916 | BadTopKGroupKey { source, .. }
1917 | BadTopKOrdering { source, .. }
1918 | BadLetRecBindings { source }
1919 | Shadowing { source, .. }
1920 | DisallowedDummy { source, .. } => Some(source),
1921 Recursion { .. } => None,
1922 }
1923 }
1924
1925 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1926 where
1927 H: ExprHumanizer,
1928 {
1929 if let Some(source) = self.source() {
1930 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1931 }
1932
1933 use TypeError::*;
1934 match self {
1935 Unbound { source: _, id, typ } => {
1936 let typ = columns_pretty(&typ.column_types, humanizer);
1937 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1938 }
1939 NoSuchColumn {
1940 source: _,
1941 expr,
1942 col,
1943 } => writeln!(f, "{expr} references non-existent column {col}")?,
1944 MismatchColumn {
1945 source: _,
1946 got,
1947 expected,
1948 diffs,
1949 message,
1950 } => {
1951 let got = humanizer.humanize_column_type_repr(got, false);
1952 let expected = humanizer.humanize_column_type_repr(expected, false);
1953 writeln!(
1954 f,
1955 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1956 )?;
1957
1958 for diff in diffs {
1959 diff.humanize(2, humanizer, f)?;
1960 }
1961 }
1962 MismatchColumns {
1963 source: _,
1964 got,
1965 expected,
1966 diffs,
1967 message,
1968 } => {
1969 let got = columns_pretty(got, humanizer);
1970 let expected = columns_pretty(expected, humanizer);
1971
1972 writeln!(
1973 f,
1974 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1975 )?;
1976
1977 for diff in diffs {
1978 diff.humanize(humanizer, f)?;
1979 }
1980 }
1981 BadConstantRowLen {
1982 source: _,
1983 got,
1984 expected,
1985 } => {
1986 let expected = columns_pretty(expected, humanizer);
1987 writeln!(
1988 f,
1989 "bad constant row\n row has length {got}\nexpected row of type {expected}"
1990 )?
1991 }
1992 BadConstantRow {
1993 source: _,
1994 mismatches,
1995 expected,
1996 } => {
1997 let expected = columns_pretty(expected, humanizer);
1998
1999 let num_mismatches = mismatches.len();
2000 let plural = if num_mismatches == 1 { "" } else { "es" };
2001 writeln!(
2002 f,
2003 "bad constant row\n got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
2004 )?;
2005
2006 if num_mismatches > 0 {
2007 writeln!(f, "")?;
2008 for (col, diff) in mismatches.iter() {
2009 writeln!(f, " column #{col}:")?;
2010 diff.humanize(8, humanizer, f)?;
2011 }
2012 }
2013 }
2014 BadProject {
2015 source: _,
2016 got,
2017 input_type,
2018 } => {
2019 let input_type = columns_pretty(input_type, humanizer);
2020
2021 writeln!(
2022 f,
2023 "projection of non-existant columns {got:?} from type {input_type}"
2024 )?
2025 }
2026 BadJoinEquivalence {
2027 source: _,
2028 got,
2029 message,
2030 } => {
2031 let got = columns_pretty(got, humanizer);
2032
2033 writeln!(f, "bad join equivalence {got}: {message}")?
2034 }
2035 BadTopKGroupKey {
2036 source: _,
2037 k,
2038 input_type,
2039 } => {
2040 let input_type = columns_pretty(input_type, humanizer);
2041
2042 writeln!(
2043 f,
2044 "TopK group key component references invalid column {k} in columns: {input_type}"
2045 )?
2046 }
2047 BadTopKOrdering {
2048 source: _,
2049 order,
2050 input_type,
2051 } => {
2052 let col = order.column;
2053 let num_cols = input_type.len();
2054 let are = if num_cols == 1 { "is" } else { "are" };
2055 let s = if num_cols == 1 { "" } else { "s" };
2056 let input_type = columns_pretty(input_type, humanizer);
2057
2058 let mode = HumanizedExplain::new(false);
2060 let order = mode.expr(order, None);
2061
2062 writeln!(
2063 f,
2064 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2065 )?
2066 }
2067 BadLetRecBindings { source: _ } => {
2068 writeln!(f, "LetRec ids and definitions don't line up")?
2069 }
2070 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2071 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2072 Recursion { error } => writeln!(f, "{error}")?,
2073 }
2074
2075 Ok(())
2076 }
2077}
2078
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    /// An `Int16` datum is accepted for a nullable `Int16` column and rejected for a
    /// non-nullable `Int32` column.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    // Property test: generate a column type, then a datum for that column type via
    // `arb_datum_for_column`, and check that `datum_difference_with_column_type` succeeds
    // exactly when `Datum::is_instance_of` holds. Cases whose datum contains a dummy are
    // rejected rather than failed.
    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    // Property test: generate a column type and a datum independently of each other and
    // check that `datum_difference_with_column_type` agrees with `Datum::is_instance_of`.
    // A dummy in the generated datum is treated as a bug in `arb_datum` and fails the test.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    /// A list holding a single null is not an instance of a non-nullable record type whose
    /// only field is a non-nullable `UInt32`, and `datum_difference_with_column_type` must
    /// report an error for it as well.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        // Build a row whose first datum is a list containing one null.
        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}