1use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19 AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20 RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27 ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28 SqlColumnType,
29};
30
/// A typechecking [`Context`] behind `Arc<Mutex<_>>`, so one context can be handed to
/// several owners (e.g. `Typecheck::new` takes one).
pub type SharedContext = Arc<Mutex<Context>>;
36
37pub fn empty_context() -> SharedContext {
39 Arc::new(Mutex::new(BTreeMap::new()))
40}
41
/// Errors reported by the MIR typechecker.
///
/// Every variant except `Recursion` carries the `source` relation expression at which the
/// problem was detected; the lifetime `'a` ties the error to the expression being checked.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// A `Get` referenced an id that is not bound in the context.
    Unbound {
        /// The expression containing the reference.
        source: &'a MirRelationExpr,
        /// The id that could not be found.
        id: Id,
        /// The type annotated on the `Get`.
        typ: ReprRelationType,
    },
    /// A scalar `Column` reference pointed past the available columns.
    NoSuchColumn {
        /// The relation expression whose scalar expression was being checked.
        source: &'a MirRelationExpr,
        /// The offending scalar expression.
        expr: &'a MirScalarExpr,
        /// The out-of-range column index.
        col: usize,
    },
    /// A single column's type differed from the expected column type.
    MismatchColumn {
        /// The expression where the mismatch was found.
        source: &'a MirRelationExpr,
        /// The type that was computed.
        got: ReprColumnType,
        /// The type that was required.
        expected: ReprColumnType,
        /// Structured explanation of how `got` and `expected` differ.
        diffs: Vec<ReprColumnTypeDifference>,
        /// Human-readable context for the mismatch.
        message: String,
    },
    /// The column types of a whole relation differed from the expected column types.
    MismatchColumns {
        /// The expression where the mismatch was found.
        source: &'a MirRelationExpr,
        /// The column types that were computed.
        got: Vec<ReprColumnType>,
        /// The column types that were required.
        expected: Vec<ReprColumnType>,
        /// Structured explanation of how `got` and `expected` differ.
        diffs: Vec<ReprRelationTypeDifference>,
        /// Human-readable context for the mismatch.
        message: String,
    },
    /// A row of datums (constant rows, literals, or indexed-filter constants) had the wrong
    /// number of columns.
    BadConstantRowLen {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// The number of datums in the row.
        got: usize,
        /// The column types the row was expected to have.
        expected: Vec<ReprColumnType>,
    },
    /// A row of datums had the right length but some datums did not fit their column type.
    BadConstantRow {
        /// The expression containing the row.
        source: &'a MirRelationExpr,
        /// Column index and difference for every datum that failed its check.
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The column types the row was expected to have.
        expected: Vec<ReprColumnType>,
    },
    /// A `Project` referenced a column outside its input.
    BadProject {
        /// The `Project` expression.
        source: &'a MirRelationExpr,
        /// The projection's output column indices.
        got: Vec<usize>,
        /// The column types of the projection's input.
        input_type: Vec<ReprColumnType>,
    },
    /// A join equivalence class failed a (strict) check.
    BadJoinEquivalence {
        /// The `Join` expression.
        source: &'a MirRelationExpr,
        /// The types of the members of the equivalence class.
        got: Vec<ReprColumnType>,
        /// Human-readable explanation.
        message: String,
    },
    /// A `TopK` group key referenced a column outside its input.
    BadTopKGroupKey {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The out-of-range group key column.
        k: usize,
        /// The column types of the `TopK`'s input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `TopK` ordering referenced a column outside its input.
    BadTopKOrdering {
        /// The `TopK` expression.
        source: &'a MirRelationExpr,
        /// The offending ordering.
        order: ColumnOrder,
        /// The column types of the `TopK`'s input.
        input_type: Vec<ReprColumnType>,
    },
    /// A `LetRec` had a different number of ids than values.
    BadLetRecBindings {
        /// The `LetRec` expression.
        source: &'a MirRelationExpr,
    },
    /// A binding reused an id that is already in scope.
    Shadowing {
        /// The binding expression.
        source: &'a MirRelationExpr,
        /// The id that was rebound.
        id: Id,
    },
    /// The expression was nested more deeply than the recursion limit allows.
    Recursion {
        /// The underlying recursion-limit error.
        error: RecursionLimitError,
    },
    /// A `Datum::Dummy` was found while dummies were disallowed.
    DisallowedDummy {
        /// The expression containing the dummy.
        source: &'a MirRelationExpr,
    },
}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173 fn from(error: RecursionLimitError) -> Self {
174 TypeError::Recursion { error }
175 }
176}
177
/// Maps each id in scope to the column types of the relation it names.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
/// One reason why a relation type (a list of column types) is not a subtype of another.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// The two relations have different numbers of columns.
    Length {
        /// Number of columns in the would-be subtype.
        len_sub: usize,
        /// Number of columns in the would-be supertype.
        len_sup: usize,
    },
    /// A particular column differs.
    Column {
        /// Index of the differing column.
        col: usize,
        /// How that column differs.
        diff: ReprColumnTypeDifference,
    },
}
200
/// One reason why a column type is not a subtype of another.
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The scalar types are incompatible (different constructors, or records of different
    /// arity).
    NotSubtype {
        /// The would-be subtype.
        sub: ReprScalarType,
        /// The would-be supertype.
        sup: ReprScalarType,
    },
    /// `sub` is nullable but `sup` is not.
    Nullability {
        /// The would-be subtype.
        sub: ReprColumnType,
        /// The would-be supertype.
        sup: ReprColumnType,
    },
    /// The element type of a container (list, map, range, array) differs.
    ElementType {
        /// Debug rendering of the container's base type.
        ctor: String,
        /// How the element types differ.
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// Record fields are missing.
    // NOTE(review): not constructed in the visible code; presumably lists fields of `sup`
    // absent from `sub` — confirm.
    RecordMissingFields {
        /// Names of the missing fields.
        missing: Vec<ColumnName>,
    },
    /// Differences among the fields of a record.
    RecordFields {
        /// One difference per differing field.
        fields: Vec<ReprColumnTypeDifference>,
    },
}
239
240impl ReprRelationTypeDifference {
241 pub fn ignore_nullability(self) -> Option<Self> {
245 use ReprRelationTypeDifference::*;
246
247 match self {
248 Length { .. } => Some(self),
249 Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250 }
251 }
252}
253
254impl ReprColumnTypeDifference {
255 pub fn ignore_nullability(self) -> Option<Self> {
259 use ReprColumnTypeDifference::*;
260
261 match self {
262 Nullability { .. } => None,
263 NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264 ElementType { ctor, element_type } => {
265 element_type
266 .ignore_nullability()
267 .map(|element_type| ElementType {
268 ctor,
269 element_type: Box::new(element_type),
270 })
271 }
272 RecordFields { fields } => {
273 let fields = fields
274 .into_iter()
275 .flat_map(|diff| diff.ignore_nullability())
276 .collect::<Vec<_>>();
277
278 if fields.is_empty() {
279 None
280 } else {
281 Some(RecordFields { fields })
282 }
283 }
284 }
285 }
286}
287
288pub fn relation_subtype_difference(
292 sub: &[ReprColumnType],
293 sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295 let mut diffs = Vec::new();
296
297 if sub.len() != sup.len() {
298 diffs.push(ReprRelationTypeDifference::Length {
299 len_sub: sub.len(),
300 len_sup: sup.len(),
301 });
302
303 return diffs;
305 }
306
307 diffs.extend(
308 sub.iter()
309 .zip_eq(sup.iter())
310 .enumerate()
311 .flat_map(|(col, (sub_ty, sup_ty))| {
312 column_subtype_difference(sub_ty, sup_ty)
313 .into_iter()
314 .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315 }),
316 );
317
318 diffs
319}
320
321pub fn column_subtype_difference(
325 sub: &ReprColumnType,
326 sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328 let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330 if sub.nullable && !sup.nullable {
331 diffs.push(ReprColumnTypeDifference::Nullability {
332 sub: sub.clone(),
333 sup: sup.clone(),
334 });
335 }
336
337 diffs
338}
339
340pub fn scalar_subtype_difference(
344 sub: &ReprScalarType,
345 sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347 use ReprScalarType::*;
348
349 let mut diffs = Vec::new();
350
351 match (sub, sup) {
352 (
353 List {
354 element_type: sub_elt,
355 ..
356 },
357 List {
358 element_type: sup_elt,
359 ..
360 },
361 )
362 | (
363 Map {
364 value_type: sub_elt,
365 ..
366 },
367 Map {
368 value_type: sup_elt,
369 ..
370 },
371 )
372 | (
373 Range {
374 element_type: sub_elt,
375 ..
376 },
377 Range {
378 element_type: sup_elt,
379 ..
380 },
381 )
382 | (Array(sub_elt), Array(sup_elt)) => {
383 let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384 diffs.extend(
385 scalar_subtype_difference(sub_elt, sup_elt)
386 .into_iter()
387 .map(|diff| ReprColumnTypeDifference::ElementType {
388 ctor: ctor.clone(),
389 element_type: Box::new(diff),
390 }),
391 );
392 }
393 (
394 Record {
395 fields: sub_fields, ..
396 },
397 Record {
398 fields: sup_fields, ..
399 },
400 ) => {
401 if sub_fields.len() != sup_fields.len() {
402 diffs.push(ReprColumnTypeDifference::NotSubtype {
403 sub: sub.clone(),
404 sup: sup.clone(),
405 });
406 return diffs;
407 }
408
409 for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410 diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411 }
412 }
413 (_, _) => {
414 if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415 diffs.push(ReprColumnTypeDifference::NotSubtype {
416 sub: sub.clone(),
417 sup: sup.clone(),
418 })
419 }
420 }
421 };
422
423 diffs
424}
425
426pub fn scalar_union(
430 typ: &mut ReprScalarType,
431 other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433 use ReprScalarType::*;
434
435 let mut diffs = Vec::new();
436
437 let ctor = ReprScalarBaseType::from(&*typ);
439 match (typ, other) {
440 (
441 List {
442 element_type: typ_elt,
443 },
444 List {
445 element_type: other_elt,
446 },
447 )
448 | (
449 Map {
450 value_type: typ_elt,
451 },
452 Map {
453 value_type: other_elt,
454 },
455 )
456 | (
457 Range {
458 element_type: typ_elt,
459 },
460 Range {
461 element_type: other_elt,
462 },
463 )
464 | (Array(typ_elt), Array(other_elt)) => {
465 let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466 diffs.extend(
467 res.into_iter()
468 .map(|diff| ReprColumnTypeDifference::ElementType {
469 ctor: format!("{ctor:?}"),
470 element_type: Box::new(diff),
471 }),
472 );
473 }
474 (
475 Record { fields: typ_fields },
476 Record {
477 fields: other_fields,
478 },
479 ) => {
480 if typ_fields.len() != other_fields.len() {
481 diffs.push(ReprColumnTypeDifference::NotSubtype {
482 sub: ReprScalarType::Record {
483 fields: typ_fields.clone(),
484 },
485 sup: other.clone(),
486 });
487 return diffs;
488 }
489
490 for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491 diffs.extend(column_union(typ_ty, other_ty));
492 }
493 }
494 (typ, _) => {
495 if ctor != ReprScalarBaseType::from(other) {
496 diffs.push(ReprColumnTypeDifference::NotSubtype {
497 sub: typ.clone(),
498 sup: other.clone(),
499 })
500 }
501 }
502 };
503
504 diffs
505}
506
507pub fn column_union(
511 typ: &mut ReprColumnType,
512 other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514 let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516 if diffs.is_empty() {
517 typ.nullable |= other.nullable;
518 }
519
520 diffs
521}
522
523pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528 if sub.len() != sup.len() {
529 return false;
530 }
531
532 sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533 (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534 })
535}
536
/// Why a datum does not inhabit an expected type.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// The datum is `Datum::Null` where a null is not permitted.
    Null,
    /// The datum's kind does not match the expected scalar type.
    Mismatch {
        /// `Debug` rendering of the offending datum.
        got_debug: String,
        /// The scalar type that was expected.
        expected: ReprScalarType,
    },
    /// A datum had the wrong number of dimensions or fields.
    MismatchDimensions {
        /// Name of the type being checked (e.g. `"int2vector"`, `"record"`).
        ctor: String,
        /// The number found.
        got: usize,
        /// The number required.
        expected: usize,
    },
    /// An element inside a container datum failed its check.
    ElementType {
        /// Name of the container (e.g. `"array"`, `"list"`, `"map"`, `"range"`).
        ctor: String,
        /// The element's difference.
        element_type: Box<DatumTypeDifference>,
    },
}
567
/// Checks that `datum` inhabits `column_type`, returning the first difference found.
///
/// `Datum::Null` is accepted only when `column_type` is nullable; everything else is
/// checked against `column_type.scalar_type`.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    /// Checks `datum` against `scalar_type` alone. `Datum::Null` is always rejected here
    /// and `Datum::Dummy` is always accepted.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        /// Builds the generic "wrong kind of datum" error.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        if let ReprScalarType::Jsonb = scalar_type {
            // JSON values: scalar JSON datums are accepted; lists and maps must contain
            // only valid JSON values (checked recursively against `Jsonb`).
            match datum {
                Datum::Dummy => Ok(()),
                Datum::Null => Err(DatumTypeDifference::Null),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                Datum::List(list) => {
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            /// Wraps an element's difference with the name of its container.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            // Arm order matters: `Dummy` and `Null` are handled before any kind-specific arm.
            match (datum, scalar_type) {
                (Datum::Dummy, _) => Ok(()),
                (Datum::Null, _) => Err(DatumTypeDifference::Null),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    // Null elements are skipped; every other element must match `t`.
                    for e in array.elements().iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    // An int2vector is a one-dimensional array of int16s. Unlike the other
                    // containers, null elements are not skipped here.
                    if array.dims().len() != 1 {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    for e in list.iter() {
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    // Records are stored as lists with one datum per field.
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        // NOTE(review): null fields are skipped even when the field type
                        // `t` is non-nullable — confirm this leniency is intended.
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    // Only values are checked; null values are skipped.
                    for (_, v) in map.iter() {
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // `JsonNull` is only valid for `Jsonb`, which was handled above.
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    // An `inner` of `None` (presumably the empty range) has no bounds to
                    // check; otherwise each present bound must match the element type.
                    match inner {
                        None => Ok(()),
                        Some(inner) => {
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // Top-level nulls are fine for nullable columns; the scalar check would reject them.
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
772
773fn row_difference_with_column_types<'a>(
774 source: &'a MirRelationExpr,
775 datums: &Vec<Datum<'_>>,
776 column_types: &[ReprColumnType],
777) -> Result<(), TypeError<'a>> {
778 if datums.len() != column_types.len() {
780 return Err(TypeError::BadConstantRowLen {
781 source,
782 got: datums.len(),
783 expected: column_types.to_vec(),
784 });
785 }
786
787 let mut mismatches = Vec::new();
789 for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
790 if let Err(e) = datum_difference_with_column_type(d, ty) {
791 mismatches.push((i, e));
792 }
793 }
794 if !mismatches.is_empty() {
795 return Err(TypeError::BadConstantRow {
796 source,
797 mismatches,
798 expected: column_types.to_vec(),
799 });
800 }
801
802 Ok(())
803}
/// A configurable typechecker for `MirRelationExpr`s.
///
/// Built with [`Typecheck::new`] and refined with the builder-style flag methods.
#[derive(Debug)]
pub struct Typecheck {
    /// The shared context passed to `Typecheck::new`.
    // NOTE(review): not read by the methods in this section; presumably consulted elsewhere.
    ctx: SharedContext,
    /// Set by `disallow_new_globals`.
    // NOTE(review): not read in this section; presumably controls whether new global ids
    // may be added to `ctx` — confirm.
    disallow_new_globals: bool,
    /// When set, join equivalence classes with mixed nullability, or whose members are all
    /// nullable, are logged at debug level.
    strict_join_equivalences: bool,
    /// When set, any `Datum::Dummy` encountered is reported as `TypeError::DisallowedDummy`.
    disallow_dummy: bool,
    /// Guard bounding recursion depth (see the `CheckedRecursion` impl).
    recursion_guard: RecursionGuard,
}
818
/// Exposes the typechecker's recursion guard so `checked_recur` can bound its depth.
impl CheckedRecursion for Typecheck {
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
824
825impl Typecheck {
826 pub fn new(ctx: SharedContext) -> Self {
828 Self {
829 ctx,
830 disallow_new_globals: false,
831 strict_join_equivalences: false,
832 disallow_dummy: false,
833 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
834 }
835 }
836
837 pub fn disallow_new_globals(mut self) -> Self {
841 self.disallow_new_globals = true;
842 self
843 }
844
845 pub fn strict_join_equivalences(mut self) -> Self {
849 self.strict_join_equivalences = true;
850
851 self
852 }
853
854 pub fn disallow_dummy(mut self) -> Self {
856 self.disallow_dummy = true;
857 self
858 }
859
860 pub fn typecheck<'a>(
871 &self,
872 expr: &'a MirRelationExpr,
873 ctx: &Context,
874 ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
875 use MirRelationExpr::*;
876
877 self.checked_recur(|tc| match expr {
878 Constant { typ, rows } => {
879 if let Ok(rows) = rows {
880 for (row, _id) in rows {
881 let datums = row.unpack();
882
883 row_difference_with_column_types(expr, &datums, &typ.column_types.iter().map(ReprColumnType::from).collect_vec())?;
884
885 if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
886 return Err(TypeError::DisallowedDummy {
887 source: expr,
888 });
889 }
890 }
891 }
892
893 Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
894 }
895 Get { typ, id, .. } => {
896 if let Id::Global(_global_id) = id {
897 if !ctx.contains_key(id) {
898 return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
900 }
901 }
902
903 let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
904 source: expr,
905 id: id.clone(),
906 typ: ReprRelationType::from(typ),
907 })?;
908
909 let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
910
911 let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
913
914 if !diffs.is_empty() {
915 return Err(TypeError::MismatchColumns {
916 source: expr,
917 got: column_types,
918 expected: ctx_typ.clone(),
919 diffs,
920 message: "annotation did not match context type".to_string(),
921 });
922 }
923
924 Ok(column_types)
925 }
926 Project { input, outputs } => {
927 let t_in = tc.typecheck(input, ctx)?;
928
929 for x in outputs {
930 if *x >= t_in.len() {
931 return Err(TypeError::BadProject {
932 source: expr,
933 got: outputs.clone(),
934 input_type: t_in,
935 });
936 }
937 }
938
939 Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
940 }
941 Map { input, scalars } => {
942 let mut t_in = tc.typecheck(input, ctx)?;
943
944 for scalar_expr in scalars.iter() {
945 t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
946
947 if self.disallow_dummy && scalar_expr.contains_dummy() {
948 return Err(TypeError::DisallowedDummy {
949 source: expr,
950 });
951 }
952 }
953
954 Ok(t_in)
955 }
956 FlatMap { input, func, exprs } => {
957 let mut t_in = tc.typecheck(input, ctx)?;
958
959 let mut t_exprs = Vec::with_capacity(exprs.len());
960 for scalar_expr in exprs {
961 t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
962
963 if self.disallow_dummy && scalar_expr.contains_dummy() {
964 return Err(TypeError::DisallowedDummy {
965 source: expr,
966 });
967 }
968 }
969 let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
972
973 t_in.extend(t_out);
975 Ok(t_in)
976 }
977 Filter { input, predicates } => {
978 let mut t_in = tc.typecheck(input, ctx)?;
979
980 for column in non_nullable_columns(predicates) {
983 t_in[column].nullable = false;
984 }
985
986 for scalar_expr in predicates {
987 let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
988
989 if t.scalar_type != ReprScalarType::Bool {
993 let sub = t.scalar_type.clone();
994
995 return Err(TypeError::MismatchColumn {
996 source: expr,
997 got: t,
998 expected: ReprColumnType {
999 scalar_type: ReprScalarType::Bool,
1000 nullable: true,
1001 },
1002 diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
1003 message: "expected boolean condition".to_string(),
1004 });
1005 }
1006
1007 if self.disallow_dummy && scalar_expr.contains_dummy() {
1008 return Err(TypeError::DisallowedDummy {
1009 source: expr,
1010 });
1011 }
1012 }
1013
1014 Ok(t_in)
1015 }
1016 Join {
1017 inputs,
1018 equivalences,
1019 implementation,
1020 } => {
1021 let mut t_in_global = Vec::new();
1022 let mut t_in_local = vec![Vec::new(); inputs.len()];
1023
1024 for (i, input) in inputs.iter().enumerate() {
1025 let input_t = tc.typecheck(input, ctx)?;
1026 t_in_global.extend(input_t.clone());
1027 t_in_local[i] = input_t;
1028 }
1029
1030 for eq_class in equivalences {
1031 let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1032
1033 let mut all_nullable = true;
1034
1035 for scalar_expr in eq_class {
1036 let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1038
1039 if !t_expr.nullable {
1040 all_nullable = false;
1041 }
1042
1043 if let Some(t_first) = t_exprs.get(0) {
1044 let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
1045 if !diffs.is_empty() {
1046 return Err(TypeError::MismatchColumn {
1047 source: expr,
1048 got: t_expr,
1049 expected: t_first.clone(),
1050 diffs,
1051 message: "equivalence class members have different scalar types".to_string(),
1052 });
1053 }
1054
1055 if self.strict_join_equivalences {
1059 if t_expr.nullable != t_first.nullable {
1060 let sub = t_expr.clone();
1061 let sup = t_first.clone();
1062
1063 let err = TypeError::MismatchColumn {
1064 source: expr,
1065 got: t_expr.clone(),
1066 expected: t_first.clone(),
1067 diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
1068 message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
1069 };
1070
1071 ::tracing::debug!("{err}");
1073 }
1074 }
1075 }
1076
1077 if self.disallow_dummy && scalar_expr.contains_dummy() {
1078 return Err(TypeError::DisallowedDummy {
1079 source: expr,
1080 });
1081 }
1082
1083 t_exprs.push(t_expr);
1084 }
1085
1086 if self.strict_join_equivalences && all_nullable {
1087 let err = TypeError::BadJoinEquivalence {
1088 source: expr,
1089 got: t_exprs,
1090 message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1091 };
1092
1093 ::tracing::debug!("{err}");
1095 }
1096 }
1097
1098 match implementation {
1100 JoinImplementation::Differential((start_idx, first_key, _), others) => {
1101 if let Some(key) = first_key {
1102 for k in key {
1103 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1104 }
1105 }
1106
1107 for (idx, key, _) in others {
1108 for k in key {
1109 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1110 }
1111 }
1112 }
1113 JoinImplementation::DeltaQuery(plans) => {
1114 for plan in plans {
1115 for (idx, key, _) in plan {
1116 for k in key {
1117 let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1118 }
1119 }
1120 }
1121 }
1122 JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1123 let typ: Vec<ReprColumnType> = key
1124 .iter()
1125 .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1126 .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1127
1128 for row in consts {
1129 let datums = row.unpack();
1130
1131 row_difference_with_column_types(expr, &datums, &typ)?;
1132 }
1133 }
1134 JoinImplementation::Unimplemented => (),
1135 }
1136
1137 Ok(t_in_global)
1138 }
1139 Reduce {
1140 input,
1141 group_key,
1142 aggregates,
1143 monotonic: _,
1144 expected_group_size: _,
1145 } => {
1146 let t_in = tc.typecheck(input, ctx)?;
1147
1148 let mut t_out = group_key
1149 .iter()
1150 .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1151 .collect::<Result<Vec<_>, _>>()?;
1152
1153 if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
1154 return Err(TypeError::DisallowedDummy {
1155 source: expr,
1156 });
1157 }
1158
1159 for agg in aggregates {
1160 t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1161 }
1162
1163 Ok(t_out)
1164 }
1165 TopK {
1166 input,
1167 group_key,
1168 order_key,
1169 limit: _,
1170 offset: _,
1171 monotonic: _,
1172 expected_group_size: _,
1173 } => {
1174 let t_in = tc.typecheck(input, ctx)?;
1175
1176 for &k in group_key {
1177 if k >= t_in.len() {
1178 return Err(TypeError::BadTopKGroupKey {
1179 source: expr,
1180 k,
1181 input_type: t_in,
1182 });
1183 }
1184 }
1185
1186 for order in order_key {
1187 if order.column >= t_in.len() {
1188 return Err(TypeError::BadTopKOrdering {
1189 source: expr,
1190 order: order.clone(),
1191 input_type: t_in,
1192 });
1193 }
1194 }
1195
1196 Ok(t_in)
1197 }
1198 Negate { input } => tc.typecheck(input, ctx),
1199 Threshold { input } => tc.typecheck(input, ctx),
1200 Union { base, inputs } => {
1201 let mut t_base = tc.typecheck(base, ctx)?;
1202
1203 for input in inputs {
1204 let t_input = tc.typecheck(input, ctx)?;
1205
1206 let len_sub = t_base.len();
1207 let len_sup = t_input.len();
1208 if len_sub != len_sup {
1209 return Err(TypeError::MismatchColumns {
1210 source: expr,
1211 got: t_base.clone(),
1212 expected: t_input,
1213 diffs: vec![ReprRelationTypeDifference::Length {
1214 len_sub,
1215 len_sup,
1216 }],
1217 message: "Union branches have different numbers of columns".to_string(),
1218 });
1219 }
1220
1221 for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1222 let diffs = column_union(base_col, &input_col);
1223 if !diffs.is_empty() {
1224 return Err(TypeError::MismatchColumn {
1225 source: expr,
1226 got: input_col,
1227 expected: base_col.clone(),
1228 diffs,
1229 message:
1230 "couldn't compute union of column types in Union"
1231 .to_string(),
1232 });
1233 }
1234
1235 }
1236 }
1237
1238 Ok(t_base)
1239 }
1240 Let { id, value, body } => {
1241 let t_value = tc.typecheck(value, ctx)?;
1242
1243 let binding = Id::Local(*id);
1244 if ctx.contains_key(&binding) {
1245 return Err(TypeError::Shadowing {
1246 source: expr,
1247 id: binding,
1248 });
1249 }
1250
1251 let mut body_ctx = ctx.clone();
1252 body_ctx.insert(Id::Local(*id), t_value);
1253
1254 tc.typecheck(body, &body_ctx)
1255 }
1256 LetRec { ids, values, body, limits: _ } => {
1257 if ids.len() != values.len() {
1258 return Err(TypeError::BadLetRecBindings { source: expr });
1259 }
1260
1261 let mut ctx = ctx.clone();
1264 for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1266 tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1267 }
1268
1269 for (id, value) in ids.iter().zip_eq(values.iter()) {
1270 let typ = tc.typecheck(value, &ctx)?;
1271
1272 let id = Id::Local(id.clone());
1273 if let Some(ctx_typ) = ctx.get_mut(&id) {
1274 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1275 let diffs = column_union(base_col, &input_col);
1277 if !diffs.is_empty() {
1278 return Err(TypeError::MismatchColumn {
1279 source: expr,
1280 got: input_col,
1281 expected: base_col.clone(),
1282 diffs,
1283 message:
1284 "couldn't compute union of column types in LetRec"
1285 .to_string(),
1286 })
1287 }
1288 }
1289 } else {
1290 ctx.insert(id, typ);
1292 }
1293 }
1294
1295 tc.typecheck(body, &ctx)
1296 }
1297 ArrangeBy { input, keys } => {
1298 let t_in = tc.typecheck(input, ctx)?;
1299
1300 for key in keys {
1301 for k in key {
1302 let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1303 }
1304 }
1305
1306 Ok(t_in)
1307 }
1308 })
1309 }
1310
1311 fn collect_recursive_variable_types<'a>(
1315 &self,
1316 expr: &'a MirRelationExpr,
1317 ids: &[LocalId],
1318 ctx: &mut Context,
1319 ) -> Result<(), TypeError<'a>> {
1320 use MirRelationExpr::*;
1321
1322 self.checked_recur(|tc| {
1323 match expr {
1324 Get {
1325 id: Id::Local(id),
1326 typ,
1327 ..
1328 } => {
1329 if !ids.contains(id) {
1330 return Ok(());
1331 }
1332
1333 let id = Id::Local(id.clone());
1334 if let Some(ctx_typ) = ctx.get_mut(&id) {
1335 let typ = typ
1336 .column_types
1337 .iter()
1338 .map(ReprColumnType::from)
1339 .collect_vec();
1340
1341 if ctx_typ.len() != typ.len() {
1342 let diffs = relation_subtype_difference(&typ, ctx_typ);
1343
1344 return Err(TypeError::MismatchColumns {
1345 source: expr,
1346 got: typ,
1347 expected: ctx_typ.clone(),
1348 diffs,
1349 message: "environment and type annotation did not match"
1350 .to_string(),
1351 });
1352 }
1353
1354 for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1355 let diffs = column_union(base_col, &input_col);
1356 if !diffs.is_empty() {
1357 return Err(TypeError::MismatchColumn {
1358 source: expr,
1359 got: input_col,
1360 expected: base_col.clone(),
1361 diffs,
1362 message:
1363 "couldn't compute union of column types in Get and context"
1364 .to_string(),
1365 });
1366 }
1367 }
1368 } else {
1369 ctx.insert(
1370 id,
1371 typ.column_types
1372 .iter()
1373 .map(ReprColumnType::from)
1374 .collect_vec(),
1375 );
1376 }
1377 }
1378 Get {
1379 id: Id::Global(..), ..
1380 }
1381 | Constant { .. } => (),
1382 Let { id, value, body } => {
1383 tc.collect_recursive_variable_types(value, ids, ctx)?;
1384
1385 if ids.contains(id) {
1387 return Err(TypeError::Shadowing {
1388 source: expr,
1389 id: Id::Local(*id),
1390 });
1391 }
1392
1393 tc.collect_recursive_variable_types(body, ids, ctx)?;
1394 }
1395 LetRec {
1396 ids: inner_ids,
1397 values,
1398 body,
1399 limits: _,
1400 } => {
1401 for inner_id in inner_ids {
1402 if ids.contains(inner_id) {
1403 return Err(TypeError::Shadowing {
1404 source: expr,
1405 id: Id::Local(*inner_id),
1406 });
1407 }
1408 }
1409
1410 for value in values {
1411 tc.collect_recursive_variable_types(value, ids, ctx)?;
1412 }
1413
1414 tc.collect_recursive_variable_types(body, ids, ctx)?;
1415 }
1416 Project { input, .. }
1417 | Map { input, .. }
1418 | FlatMap { input, .. }
1419 | Filter { input, .. }
1420 | Reduce { input, .. }
1421 | TopK { input, .. }
1422 | Negate { input }
1423 | Threshold { input }
1424 | ArrangeBy { input, .. } => {
1425 tc.collect_recursive_variable_types(input, ids, ctx)?;
1426 }
1427 Join { inputs, .. } => {
1428 for input in inputs {
1429 tc.collect_recursive_variable_types(input, ids, ctx)?;
1430 }
1431 }
1432 Union { base, inputs } => {
1433 tc.collect_recursive_variable_types(base, ids, ctx)?;
1434
1435 for input in inputs {
1436 tc.collect_recursive_variable_types(input, ids, ctx)?;
1437 }
1438 }
1439 }
1440
1441 Ok(())
1442 })
1443 }
1444
1445 fn typecheck_scalar<'a>(
1446 &self,
1447 expr: &'a MirScalarExpr,
1448 source: &'a MirRelationExpr,
1449 column_types: &[ReprColumnType],
1450 ) -> Result<ReprColumnType, TypeError<'a>> {
1451 use MirScalarExpr::*;
1452
1453 self.checked_recur(|tc| match expr {
1454 Column(i, _) => match column_types.get(*i) {
1455 Some(ty) => Ok(ty.clone()),
1456 None => Err(TypeError::NoSuchColumn {
1457 source,
1458 expr,
1459 col: *i,
1460 }),
1461 },
1462 Literal(row, typ) => {
1463 let typ = ReprColumnType::from(typ);
1464 if let Ok(row) = row {
1465 let datums = row.unpack();
1466
1467 row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1468 }
1469
1470 Ok(typ)
1471 }
1472 CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1473 CallUnary { expr, func } => {
1474 let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1475 let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1476 Ok(ReprColumnType::from(&typ_out))
1477 }
1478 CallBinary { expr1, expr2, func } => {
1479 let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1480 let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1481 let typ_out = func.output_type(
1482 SqlColumnType::from_repr(&typ_in1),
1483 SqlColumnType::from_repr(&typ_in2),
1484 );
1485 Ok(ReprColumnType::from(&typ_out))
1486 }
1487 CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1488 &func.output_type(
1489 exprs
1490 .iter()
1491 .map(|e| {
1492 tc.typecheck_scalar(e, source, column_types)
1493 .map(|typ| SqlColumnType::from_repr(&typ))
1494 })
1495 .collect::<Result<Vec<_>, TypeError>>()?,
1496 ),
1497 )),
1498 If { cond, then, els } => {
1499 let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1500
1501 if cond_type.scalar_type != ReprScalarType::Bool {
1505 let sub = cond_type.scalar_type.clone();
1506
1507 return Err(TypeError::MismatchColumn {
1508 source,
1509 got: cond_type,
1510 expected: ReprColumnType {
1511 scalar_type: ReprScalarType::Bool,
1512 nullable: true,
1513 },
1514 diffs: vec![ReprColumnTypeDifference::NotSubtype {
1515 sub,
1516 sup: ReprScalarType::Bool,
1517 }],
1518 message: "expected boolean condition".to_string(),
1519 });
1520 }
1521
1522 let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1523 let else_type = tc.typecheck_scalar(els, source, column_types)?;
1524
1525 let diffs = column_union(&mut then_type, &else_type);
1526 if !diffs.is_empty() {
1527 return Err(TypeError::MismatchColumn {
1528 source,
1529 got: then_type,
1530 expected: else_type,
1531 diffs,
1532 message: "couldn't compute union of column types for If".to_string(),
1533 });
1534 }
1535
1536 Ok(then_type)
1537 }
1538 })
1539 }
1540
1541 pub fn typecheck_aggregate<'a>(
1543 &self,
1544 expr: &'a AggregateExpr,
1545 source: &'a MirRelationExpr,
1546 column_types: &[ReprColumnType],
1547 ) -> Result<ReprColumnType, TypeError<'a>> {
1548 self.checked_recur(|tc| {
1549 let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1550
1551 Ok(ReprColumnType::from(
1554 &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1555 ))
1556 })
1557 }
1558}
1559
/// Reports a typechecking problem.
///
/// When `$severity` is `true`, the message goes through `soft_panic_or_log!`. Otherwise it is
/// emitted at `debug` level through `tracing`. The remaining tokens are format arguments.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        let is_severe: bool = $severity;
        if !is_severe {
            ::tracing::debug!($($arg)+);
        } else {
            soft_panic_or_log!($($arg)+);
        }
    }};
}
1572
1573impl crate::Transform for Typecheck {
1574 fn name(&self) -> &'static str {
1575 "Typecheck"
1576 }
1577
1578 fn actually_perform_transform(
1579 &self,
1580 relation: &mut MirRelationExpr,
1581 transform_ctx: &mut crate::TransformCtx,
1582 ) -> Result<(), crate::TransformError> {
1583 let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1584
1585 let expected = transform_ctx
1586 .global_id
1587 .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1588
1589 if let Some(id) = transform_ctx.global_id {
1590 if self.disallow_new_globals
1591 && expected.is_none()
1592 && transform_ctx.global_id.is_some()
1593 && !id.is_transient()
1594 {
1595 type_error!(
1596 false, "type warning: new non-transient global id {id}\n{}",
1598 relation.pretty()
1599 );
1600 }
1601 }
1602
1603 let got = self.typecheck(relation, &typecheck_ctx);
1604
1605 let humanizer = mz_repr::explain::DummyHumanizer;
1606
1607 match (got, expected) {
1608 (Ok(got), Some(expected)) => {
1609 let id = transform_ctx.global_id.unwrap();
1610
1611 let diffs = relation_subtype_difference(expected, &got);
1613 if !diffs.is_empty() {
1614 let severity = diffs
1616 .iter()
1617 .any(|diff| diff.clone().ignore_nullability().is_some());
1618
1619 let err = TypeError::MismatchColumns {
1620 source: relation,
1621 got,
1622 expected: expected.clone(),
1623 diffs,
1624 message: format!(
1625 "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1626 ),
1627 };
1628
1629 type_error!(severity, "type error in known global id {id}:\n{err}");
1630 }
1631 }
1632 (Ok(got), None) => {
1633 if let Some(id) = transform_ctx.global_id {
1634 typecheck_ctx.insert(Id::Global(id), got);
1635 }
1636 }
1637 (Err(err), _) => {
1638 let (expected, binding) = match expected {
1639 Some(expected) => {
1640 let id = transform_ctx.global_id.unwrap();
1641 (
1642 format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1643 format!("known global id {id}"),
1644 )
1645 }
1646 None => ("".to_string(), "transient query".to_string()),
1647 };
1648
1649 type_error!(
1650 true, "type error in {binding}:\n{err}\n{expected}{}",
1652 relation.pretty()
1653 );
1654 }
1655 }
1656
1657 Ok(())
1658 }
1659}
1660
1661pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1663where
1664 H: ExprHumanizer,
1665{
1666 let mut s = String::with_capacity(2 + 3 * cols.len());
1667
1668 s.push('(');
1669
1670 let mut it = cols.iter().peekable();
1671 while let Some(col) = it.next() {
1672 s.push_str(&humanizer.humanize_column_type_repr(col, false));
1673
1674 if it.peek().is_some() {
1675 s.push_str(", ");
1676 }
1677 }
1678
1679 s.push(')');
1680
1681 s
1682}
1683
1684impl ReprRelationTypeDifference {
1685 pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1689 where
1690 H: ExprHumanizer,
1691 {
1692 use ReprRelationTypeDifference::*;
1693 match self {
1694 Length { len_sub, len_sup } => {
1695 writeln!(
1696 f,
1697 " number of columns do not match ({len_sub} != {len_sup})"
1698 )
1699 }
1700 Column { col, diff } => {
1701 writeln!(f, " column {col} differs:")?;
1702 diff.humanize(4, h, f)
1703 }
1704 }
1705 }
1706}
1707
1708impl ReprColumnTypeDifference {
1709 pub fn humanize<H>(
1711 &self,
1712 indent: usize,
1713 h: &H,
1714 f: &mut std::fmt::Formatter<'_>,
1715 ) -> std::fmt::Result
1716 where
1717 H: ExprHumanizer,
1718 {
1719 use ReprColumnTypeDifference::*;
1720
1721 write!(f, "{:indent$}", "")?;
1723
1724 match self {
1725 NotSubtype { sub, sup } => {
1726 let sub = h.humanize_scalar_type_repr(sub, false);
1727 let sup = h.humanize_scalar_type_repr(sup, false);
1728
1729 writeln!(f, "{sub} is a not a subtype of {sup}")
1730 }
1731 Nullability { sub, sup } => {
1732 let sub = h.humanize_column_type_repr(sub, false);
1733 let sup = h.humanize_column_type_repr(sup, false);
1734
1735 writeln!(f, "{sub} is nullable but {sup} is not")
1736 }
1737 ElementType { ctor, element_type } => {
1738 writeln!(f, "{ctor} element types differ:")?;
1739
1740 element_type.humanize(indent + 2, h, f)
1741 }
1742 RecordMissingFields { missing } => {
1743 write!(f, "missing column fields:")?;
1744 for col in missing {
1745 write!(f, " {col}")?;
1746 }
1747 f.write_char('\n')
1748 }
1749 RecordFields { fields } => {
1750 writeln!(f, "{} record fields differ:", fields.len())?;
1751
1752 for (i, diff) in fields.iter().enumerate() {
1753 writeln!(f, "{:indent$} field {i}:", "")?;
1754 diff.humanize(indent + 4, h, f)?;
1755 }
1756 Ok(())
1757 }
1758 }
1759 }
1760}
1761
1762impl DatumTypeDifference {
1763 pub fn humanize<H>(
1765 &self,
1766 indent: usize,
1767 h: &H,
1768 f: &mut std::fmt::Formatter<'_>,
1769 ) -> std::fmt::Result
1770 where
1771 H: ExprHumanizer,
1772 {
1773 write!(f, "{:indent$}", "")?;
1775
1776 match self {
1777 DatumTypeDifference::Null => writeln!(f, "unexpected null")?,
1778 DatumTypeDifference::Mismatch {
1779 got_debug,
1780 expected,
1781 } => {
1782 let expected = h.humanize_scalar_type_repr(expected, false);
1783 writeln!(
1785 f,
1786 "got datum {got_debug}, expected representation type {expected}"
1787 )?;
1788 }
1789 DatumTypeDifference::MismatchDimensions {
1790 ctor,
1791 got,
1792 expected,
1793 } => {
1794 writeln!(
1795 f,
1796 "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1797 )?;
1798 }
1799 DatumTypeDifference::ElementType { ctor, element_type } => {
1800 writeln!(f, "{ctor} element types differ:")?;
1801 element_type.humanize(indent + 4, h, f)?;
1802 }
1803 }
1804
1805 Ok(())
1806 }
1807}
1808
/// Pairs a [`TypeError`] with an [`ExprHumanizer`] so the error can be rendered through
/// [`std::fmt::Display`] using that humanizer's type names.
// NOTE(review): presumably `Debug` is not implemented because `H` is not required to be
// `Debug`; confirm before removing this `allow`.
#[allow(missing_debug_implementations)]
pub struct TypeErrorHumanizer<'a, 'b, H>
where
    H: ExprHumanizer,
{
    /// The error to render.
    err: &'a TypeError<'a>,
    /// Used to render the types mentioned in the error.
    humanizer: &'b H,
}
1818
1819impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1820where
1821 H: ExprHumanizer,
1822{
1823 pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1825 Self { err, humanizer }
1826 }
1827}
1828
1829impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1830where
1831 H: ExprHumanizer,
1832{
1833 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1834 self.err.humanize(self.humanizer, f)
1835 }
1836}
1837
1838impl<'a> std::fmt::Display for TypeError<'a> {
1839 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1840 TypeErrorHumanizer {
1841 err: self,
1842 humanizer: &DummyHumanizer,
1843 }
1844 .fmt(f)
1845 }
1846}
1847
1848impl<'a> TypeError<'a> {
1849 pub fn source(&self) -> Option<&'a MirRelationExpr> {
1851 use TypeError::*;
1852 match self {
1853 Unbound { source, .. }
1854 | NoSuchColumn { source, .. }
1855 | MismatchColumn { source, .. }
1856 | MismatchColumns { source, .. }
1857 | BadConstantRowLen { source, .. }
1858 | BadConstantRow { source, .. }
1859 | BadProject { source, .. }
1860 | BadJoinEquivalence { source, .. }
1861 | BadTopKGroupKey { source, .. }
1862 | BadTopKOrdering { source, .. }
1863 | BadLetRecBindings { source }
1864 | Shadowing { source, .. }
1865 | DisallowedDummy { source, .. } => Some(source),
1866 Recursion { .. } => None,
1867 }
1868 }
1869
1870 fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1871 where
1872 H: ExprHumanizer,
1873 {
1874 if let Some(source) = self.source() {
1875 writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1876 }
1877
1878 use TypeError::*;
1879 match self {
1880 Unbound { source: _, id, typ } => {
1881 let typ = columns_pretty(&typ.column_types, humanizer);
1882 writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1883 }
1884 NoSuchColumn {
1885 source: _,
1886 expr,
1887 col,
1888 } => writeln!(f, "{expr} references non-existent column {col}")?,
1889 MismatchColumn {
1890 source: _,
1891 got,
1892 expected,
1893 diffs,
1894 message,
1895 } => {
1896 let got = humanizer.humanize_column_type_repr(got, false);
1897 let expected = humanizer.humanize_column_type_repr(expected, false);
1898 writeln!(
1899 f,
1900 "mismatched column types: {message}\n got {got}\nexpected {expected}"
1901 )?;
1902
1903 for diff in diffs {
1904 diff.humanize(2, humanizer, f)?;
1905 }
1906 }
1907 MismatchColumns {
1908 source: _,
1909 got,
1910 expected,
1911 diffs,
1912 message,
1913 } => {
1914 let got = columns_pretty(got, humanizer);
1915 let expected = columns_pretty(expected, humanizer);
1916
1917 writeln!(
1918 f,
1919 "mismatched relation types: {message}\n got {got}\nexpected {expected}"
1920 )?;
1921
1922 for diff in diffs {
1923 diff.humanize(humanizer, f)?;
1924 }
1925 }
1926 BadConstantRowLen {
1927 source: _,
1928 got,
1929 expected,
1930 } => {
1931 let expected = columns_pretty(expected, humanizer);
1932 writeln!(
1933 f,
1934 "bad constant row\n row has length {got}\nexpected row of type {expected}"
1935 )?
1936 }
1937 BadConstantRow {
1938 source: _,
1939 mismatches,
1940 expected,
1941 } => {
1942 let expected = columns_pretty(expected, humanizer);
1943
1944 let num_mismatches = mismatches.len();
1945 let plural = if num_mismatches == 1 { "" } else { "es" };
1946 writeln!(
1947 f,
1948 "bad constant row\n got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1949 )?;
1950
1951 if num_mismatches > 0 {
1952 writeln!(f, "")?;
1953 for (col, diff) in mismatches.iter() {
1954 writeln!(f, " column #{col}:")?;
1955 diff.humanize(8, humanizer, f)?;
1956 }
1957 }
1958 }
1959 BadProject {
1960 source: _,
1961 got,
1962 input_type,
1963 } => {
1964 let input_type = columns_pretty(input_type, humanizer);
1965
1966 writeln!(
1967 f,
1968 "projection of non-existant columns {got:?} from type {input_type}"
1969 )?
1970 }
1971 BadJoinEquivalence {
1972 source: _,
1973 got,
1974 message,
1975 } => {
1976 let got = columns_pretty(got, humanizer);
1977
1978 writeln!(f, "bad join equivalence {got}: {message}")?
1979 }
1980 BadTopKGroupKey {
1981 source: _,
1982 k,
1983 input_type,
1984 } => {
1985 let input_type = columns_pretty(input_type, humanizer);
1986
1987 writeln!(
1988 f,
1989 "TopK group key component references invalid column {k} in columns: {input_type}"
1990 )?
1991 }
1992 BadTopKOrdering {
1993 source: _,
1994 order,
1995 input_type,
1996 } => {
1997 let col = order.column;
1998 let num_cols = input_type.len();
1999 let are = if num_cols == 1 { "is" } else { "are" };
2000 let s = if num_cols == 1 { "" } else { "s" };
2001 let input_type = columns_pretty(input_type, humanizer);
2002
2003 let mode = HumanizedExplain::new(false);
2005 let order = mode.expr(order, None);
2006
2007 writeln!(
2008 f,
2009 "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2010 )?
2011 }
2012 BadLetRecBindings { source: _ } => {
2013 writeln!(f, "LetRec ids and definitions don't line up")?
2014 }
2015 Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2016 DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2017 Recursion { error } => writeln!(f, "{error}")?,
2018 }
2019
2020 Ok(())
2021 }
2022}
2023
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    // An `Int16` datum is accepted by a nullable `Int16` column type and rejected by a
    // non-nullable `Int32` column type.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        #![proptest_config(ProptestConfig { cases: 5000, max_global_rejects: 2500, ..Default::default() })]
        // Draw a random SQL column type and a datum generated for that type.
        // `datum_difference_with_column_type` must succeed exactly when
        // `Datum::is_instance_of` accepts the datum. Datums containing a dummy are
        // rejected via `TestCaseError::reject` rather than checked.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data((src, datum) in any::<SqlColumnType>().prop_flat_map(|src| {
            let datum = arb_datum_for_column(src.clone());
            (Just(src), datum) }
        )) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10000))]
        // Same agreement property, but the datum is generated independently of the column
        // type, so most cases exercise the mismatch path. The assertion below checks that
        // `arb_datum` produced no dummy.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(src in any::<SqlColumnType>(), datum in arb_datum(false)) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }
}