mz_transform/
reprtypecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
25use mz_repr::{
26    ColumnName, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType, Row,
27    SqlColumnType,
28};
29
/// Typechecking contexts as shared by various typechecking passes.
///
/// The context is wrapped in an `Arc<Mutex<_>>` so that one context can be shared by multiple
/// typechecker passes. Shared contexts help catch consistency issues.
pub type SharedContext = Arc<Mutex<Context>>;
35
36/// Generates an empty context
37pub fn empty_context() -> SharedContext {
38    Arc::new(Mutex::new(BTreeMap::new()))
39}
40
/// The possible forms of inconsistency/errors discovered during typechecking.
///
/// Every variant except `Recursion` has a `source` field identifying the MIR term that is home
/// to the error (though not necessarily the root cause of the error). `Recursion` carries only
/// the underlying recursion-limit error.
///
/// The lifetime `'a` is that of the borrowed expressions the error points into.
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// Unbound identifiers (local or global)
    Unbound {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The (unbound) identifier referenced
        id: Id,
        /// The type `id` was expected to have
        typ: ReprRelationType,
    },
    /// Dereference of a non-existent column
    NoSuchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// Scalar expression that references an invalid column
        expr: &'a MirScalarExpr,
        /// The invalid column referenced
        col: usize,
    },
    /// A single column type does not match
    MismatchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column type we found (`sub` type)
        got: ReprColumnType,
        /// The column type we expected (`sup` type)
        expected: ReprColumnType,
        /// The differences between these types
        diffs: Vec<ReprColumnTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// Relation column types do not match
    MismatchColumns {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column types we found (`sub` type)
        got: Vec<ReprColumnType>,
        /// The column types we expected (`sup` type)
        expected: Vec<ReprColumnType>,
        /// The differences between these types
        diffs: Vec<ReprRelationTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// A constant row does not have the correct type
    BadConstantRow {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// A constant row
        got: Row,
        /// The expected type (which that row does not have)
        expected: Vec<ReprColumnType>,
        // TODO(mgree) with a good way to get the type of a Datum, we could give a diff here
    },
    /// Projection of a non-existent column
    BadProject {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The full list of projected column indices (at least one of which is out of range)
        got: Vec<usize>,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// An equivalence class in a join was malformed
    BadJoinEquivalence {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The types of the expressions in the offending equivalence class
        got: Vec<ReprColumnType>,
        /// The problem with the join equivalences
        message: String,
    },
    /// TopK grouping by non-existent column
    BadTopKGroupKey {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The bad column reference in the group key
        k: usize,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// TopK ordering by non-existent column
    BadTopKOrdering {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The ordering used
        order: ColumnOrder,
        /// The input columns (which don't work for that ordering)
        input_type: Vec<ReprColumnType>,
    },
    /// LetRec bindings are malformed
    BadLetRecBindings {
        /// Expression with the bug
        source: &'a MirRelationExpr,
    },
    /// Local identifiers are shadowed
    Shadowing {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The id that was shadowed
        id: Id,
    },
    /// Recursion depth exceeded
    ///
    /// This is the only variant without a `source` expression.
    Recursion {
        /// The error that aborted recursion
        error: RecursionLimitError,
    },
    /// A dummy value was found
    DisallowedDummy {
        /// The expression with the dummy value
        source: &'a MirRelationExpr,
    },
}
160
161impl<'a> From<RecursionLimitError> for TypeError<'a> {
162    fn from(error: RecursionLimitError) -> Self {
163        TypeError::Recursion { error }
164    }
165}
166
/// Maps each known identifier (local or global) to the column types recorded for it
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
168
/// Characterizes differences between relation types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`.
/// Values of this type are produced by `relation_subtype_difference`.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// `sub` and `sup` don't have the same number of columns
    ///
    /// When the lengths differ, no per-column differences are reported alongside this one.
    Length {
        /// Length of `sub`
        len_sub: usize,
        /// Length of `sup`
        len_sup: usize,
    },
    /// `sub` and `sup` differ at the indicated column
    Column {
        /// The (zero-based) column at which `sub` and `sup` differ
        col: usize,
        /// The difference between `sub` and `sup`
        diff: ReprColumnTypeDifference,
    },
}
189
/// Characterizes differences between individual column types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
    ///
    /// Also reported when `sub` and `sup` are both records with a different number of fields.
    NotSubtype {
        /// Would-be subtype
        sub: ReprScalarType,
        /// Would-be supertype
        sup: ReprScalarType,
    },
    /// `sub` was nullable but `sup` was not
    Nullability {
        /// Would-be subtype
        sub: ReprColumnType,
        /// Would-be supertype
        sup: ReprColumnType,
    },
    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`s
    ElementType {
        /// The type constructor (list, array, etc.), rendered with the `Debug` format of its base type
        ctor: String,
        /// The difference in the element type
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
    // NOTE(review): not constructed by any function visible in this part of the file — confirm it is still used
    RecordMissingFields {
        /// The missing fields
        missing: Vec<ColumnName>,
    },
    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
    // NOTE(review): the subtype/union functions visible here report record field differences
    // directly rather than wrapping them in this variant — confirm where it is constructed
    RecordFields {
        /// The differences, by field
        fields: Vec<ReprColumnTypeDifference>,
    },
}
228
229impl ReprRelationTypeDifference {
230    /// Returns the same type difference, but ignoring nullability
231    ///
232    /// Returns `None` when _all_ of the differences are due to nullability
233    pub fn ignore_nullability(self) -> Option<Self> {
234        use ReprRelationTypeDifference::*;
235
236        match self {
237            Length { .. } => Some(self),
238            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
239        }
240    }
241}
242
243impl ReprColumnTypeDifference {
244    /// Returns the same type difference, but ignoring nullability
245    ///
246    /// Returns `None` when _all_ of the differences are due to nullability
247    pub fn ignore_nullability(self) -> Option<Self> {
248        use ReprColumnTypeDifference::*;
249
250        match self {
251            Nullability { .. } => None,
252            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
253            ElementType { ctor, element_type } => {
254                element_type
255                    .ignore_nullability()
256                    .map(|element_type| ElementType {
257                        ctor,
258                        element_type: Box::new(element_type),
259                    })
260            }
261            RecordFields { fields } => {
262                let fields = fields
263                    .into_iter()
264                    .flat_map(|diff| diff.ignore_nullability())
265                    .collect::<Vec<_>>();
266
267                if fields.is_empty() {
268                    None
269                } else {
270                    Some(RecordFields { fields })
271                }
272            }
273        }
274    }
275}
276
277/// Returns a list of differences that make `sub` not a subtype of `sup`
278///
279/// This function returns an empty list when `sub` is a subtype of `sup`
280pub fn relation_subtype_difference(
281    sub: &[ReprColumnType],
282    sup: &[ReprColumnType],
283) -> Vec<ReprRelationTypeDifference> {
284    let mut diffs = Vec::new();
285
286    if sub.len() != sup.len() {
287        diffs.push(ReprRelationTypeDifference::Length {
288            len_sub: sub.len(),
289            len_sup: sup.len(),
290        });
291
292        // TODO(mgree) we could do an edit-distance computation to report more errors
293        return diffs;
294    }
295
296    diffs.extend(
297        sub.iter()
298            .zip_eq(sup.iter())
299            .enumerate()
300            .flat_map(|(col, (sub_ty, sup_ty))| {
301                column_subtype_difference(sub_ty, sup_ty)
302                    .into_iter()
303                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
304            }),
305    );
306
307    diffs
308}
309
310/// Returns a list of differences that make `sub` not a subtype of `sup`
311///
312/// This function returns an empty list when `sub` is a subtype of `sup`
313pub fn column_subtype_difference(
314    sub: &ReprColumnType,
315    sup: &ReprColumnType,
316) -> Vec<ReprColumnTypeDifference> {
317    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
318
319    if sub.nullable && !sup.nullable {
320        diffs.push(ReprColumnTypeDifference::Nullability {
321            sub: sub.clone(),
322            sup: sup.clone(),
323        });
324    }
325
326    diffs
327}
328
329/// Returns a list of differences that make `sub` not a subtype of `sup`
330///
331/// This function returns an empty list when `sub` is a subtype of `sup`
332pub fn scalar_subtype_difference(
333    sub: &ReprScalarType,
334    sup: &ReprScalarType,
335) -> Vec<ReprColumnTypeDifference> {
336    use ReprScalarType::*;
337
338    let mut diffs = Vec::new();
339
340    match (sub, sup) {
341        (
342            List {
343                element_type: sub_elt,
344                ..
345            },
346            List {
347                element_type: sup_elt,
348                ..
349            },
350        )
351        | (
352            Map {
353                value_type: sub_elt,
354                ..
355            },
356            Map {
357                value_type: sup_elt,
358                ..
359            },
360        )
361        | (
362            Range {
363                element_type: sub_elt,
364                ..
365            },
366            Range {
367                element_type: sup_elt,
368                ..
369            },
370        )
371        | (Array(sub_elt), Array(sup_elt)) => {
372            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
373            diffs.extend(
374                scalar_subtype_difference(sub_elt, sup_elt)
375                    .into_iter()
376                    .map(|diff| ReprColumnTypeDifference::ElementType {
377                        ctor: ctor.clone(),
378                        element_type: Box::new(diff),
379                    }),
380            );
381        }
382        (
383            Record {
384                fields: sub_fields, ..
385            },
386            Record {
387                fields: sup_fields, ..
388            },
389        ) => {
390            if sub_fields.len() != sup_fields.len() {
391                diffs.push(ReprColumnTypeDifference::NotSubtype {
392                    sub: sub.clone(),
393                    sup: sup.clone(),
394                });
395                return diffs;
396            }
397
398            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
399                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
400            }
401        }
402        (_, _) => {
403            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
404                diffs.push(ReprColumnTypeDifference::NotSubtype {
405                    sub: sub.clone(),
406                    sup: sup.clone(),
407                })
408            }
409        }
410    };
411
412    diffs
413}
414
415/// Unions `other` into `typ`, returning a list of differences on failure
416///
417/// This function returns an empty list when `typ` and `other` are a union
418pub fn scalar_union(
419    typ: &mut ReprScalarType,
420    other: &ReprScalarType,
421) -> Vec<ReprColumnTypeDifference> {
422    use ReprScalarType::*;
423
424    let mut diffs = Vec::new();
425
426    // precomputing to appease the borrow checker
427    let ctor = ReprScalarBaseType::from(&*typ);
428    match (typ, other) {
429        (
430            List {
431                element_type: typ_elt,
432            },
433            List {
434                element_type: other_elt,
435            },
436        )
437        | (
438            Map {
439                value_type: typ_elt,
440            },
441            Map {
442                value_type: other_elt,
443            },
444        )
445        | (
446            Range {
447                element_type: typ_elt,
448            },
449            Range {
450                element_type: other_elt,
451            },
452        )
453        | (Array(typ_elt), Array(other_elt)) => {
454            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
455            diffs.extend(
456                res.into_iter()
457                    .map(|diff| ReprColumnTypeDifference::ElementType {
458                        ctor: format!("{ctor:?}"),
459                        element_type: Box::new(diff),
460                    }),
461            );
462        }
463        (
464            Record { fields: typ_fields },
465            Record {
466                fields: other_fields,
467            },
468        ) => {
469            if typ_fields.len() != other_fields.len() {
470                diffs.push(ReprColumnTypeDifference::NotSubtype {
471                    sub: ReprScalarType::Record {
472                        fields: typ_fields.clone(),
473                    },
474                    sup: other.clone(),
475                });
476                return diffs;
477            }
478
479            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
480                diffs.extend(column_union(typ_ty, other_ty));
481            }
482        }
483        (typ, _) => {
484            if ctor != ReprScalarBaseType::from(other) {
485                diffs.push(ReprColumnTypeDifference::NotSubtype {
486                    sub: typ.clone(),
487                    sup: other.clone(),
488                })
489            }
490        }
491    };
492
493    diffs
494}
495
496/// Unions `other` into `typ`, returning a list of differences on failure
497///
498/// This function returns an empty list when `typ` and `other` are a union
499pub fn column_union(
500    typ: &mut ReprColumnType,
501    other: &ReprColumnType,
502) -> Vec<ReprColumnTypeDifference> {
503    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
504
505    if diffs.is_empty() {
506        typ.nullable |= other.nullable;
507    }
508
509    diffs
510}
511
512/// Returns true when it is safe to treat a `sub` row as an `sup` row
513///
514/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
515/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
516pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
517    if sub.len() != sup.len() {
518        return false;
519    }
520
521    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
522        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
523    })
524}
525
/// Check that the visible type of each query has not been changed
///
/// A pass is created with `Typecheck::new` and then configured through the builder-style
/// methods `disallow_new_globals`, `strict_join_equivalences`, and `disallow_dummy`.
#[derive(Debug)]
pub struct Typecheck {
    /// The known types of the queries so far
    ctx: SharedContext,
    /// Whether new non-transient global IDs are treated as an error
    ///
    /// Meant to be enabled only after the context has been populated, e.g., by an earlier run
    disallow_new_globals: bool,
    /// Whether or not to be strict about join equivalences having the same nullability
    strict_join_equivalences: bool,
    /// Whether or not to disallow dummy values
    disallow_dummy: bool,
    /// Recursion guard for checked recursion
    recursion_guard: RecursionGuard,
}
540
541impl CheckedRecursion for Typecheck {
542    fn recursion_guard(&self) -> &RecursionGuard {
543        &self.recursion_guard
544    }
545}
546
547impl Typecheck {
548    /// Creates a typechecking consistency checking pass using a given shared context
549    pub fn new(ctx: SharedContext) -> Self {
550        Self {
551            ctx,
552            disallow_new_globals: false,
553            strict_join_equivalences: false,
554            disallow_dummy: false,
555            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
556        }
557    }
558
559    /// New non-transient global IDs will be treated as an error
560    ///
561    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
562    pub fn disallow_new_globals(mut self) -> Self {
563        self.disallow_new_globals = true;
564        self
565    }
566
567    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
568    ///
569    /// Only turn this on before `JoinImplementation`
570    pub fn strict_join_equivalences(mut self) -> Self {
571        self.strict_join_equivalences = true;
572
573        self
574    }
575
576    /// Disallow dummy values
577    pub fn disallow_dummy(mut self) -> Self {
578        self.disallow_dummy = true;
579        self
580    }
581
582    /// Returns the type of a relation expression or a type error.
583    ///
584    /// This function is careful to check validity, not just find out the type.
585    ///
586    /// It should be linear in the size of the AST.
587    ///
588    /// ??? should we also compute keys and return a `ReprRelationType`?
589    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
590    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
591    ///   reliably keep it stable across transformations.
592    pub fn typecheck<'a>(
593        &self,
594        expr: &'a MirRelationExpr,
595        ctx: &Context,
596    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
597        use MirRelationExpr::*;
598
599        self.checked_recur(|tc| match expr {
600            Constant { typ, rows } => {
601                if let Ok(rows) = rows {
602                    for (row, _id) in rows {
603                        let datums = row.unpack();
604
605                        // correct length
606                        if datums.len() != typ.column_types.len() {
607                            return Err(TypeError::BadConstantRow {
608                                source: expr,
609                                got: row.clone(),
610                                expected: typ.column_types.iter().map(ReprColumnType::from).collect(),
611                            });
612                        }
613
614                        // correct types
615                        if datums
616                            .iter()
617                            .zip_eq(typ.column_types.iter())
618                            .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of_sql(ty))
619                        {
620                            return Err(TypeError::BadConstantRow {
621                                source: expr,
622                                got: row.clone(),
623                                expected: typ.column_types.iter().map(ReprColumnType::from).collect(),
624                            });
625                        }
626
627                        if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
628                            return Err(TypeError::DisallowedDummy {
629                                source: expr,
630                            });
631                        }
632                    }
633                }
634
635                Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
636            }
637            Get { typ, id, .. } => {
638                if let Id::Global(_global_id) = id {
639                    if !ctx.contains_key(id) {
640                        // TODO(mgree) pass QueryContext through to check these types
641                        return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
642                    }
643                }
644
645                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
646                    source: expr,
647                    id: id.clone(),
648                    typ: ReprRelationType::from(typ),
649                })?;
650
651                let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
652
653                // covariant: the ascribed type must be a subtype of the actual type in the context
654                let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
655
656                if !diffs.is_empty() {
657                    return Err(TypeError::MismatchColumns {
658                        source: expr,
659                        got: column_types,
660                        expected: ctx_typ.clone(),
661                        diffs,
662                        message: "annotation did not match context type".to_string(),
663                    });
664                }
665
666                Ok(column_types)
667            }
668            Project { input, outputs } => {
669                let t_in = tc.typecheck(input, ctx)?;
670
671                for x in outputs {
672                    if *x >= t_in.len() {
673                        return Err(TypeError::BadProject {
674                            source: expr,
675                            got: outputs.clone(),
676                            input_type: t_in,
677                        });
678                    }
679                }
680
681                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
682            }
683            Map { input, scalars } => {
684                let mut t_in = tc.typecheck(input, ctx)?;
685
686                for scalar_expr in scalars.iter() {
687                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
688
689                    if self.disallow_dummy && scalar_expr.contains_dummy() {
690                        return Err(TypeError::DisallowedDummy {
691                            source: expr,
692                        });
693                    }
694                }
695
696                Ok(t_in)
697            }
698            FlatMap { input, func, exprs } => {
699                let mut t_in = tc.typecheck(input, ctx)?;
700
701                let mut t_exprs = Vec::with_capacity(exprs.len());
702                for scalar_expr in exprs {
703                    t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
704
705                    if self.disallow_dummy && scalar_expr.contains_dummy() {
706                        return Err(TypeError::DisallowedDummy {
707                            source: expr,
708                        });
709                    }
710                }
711                // TODO(mgree) check t_exprs agrees with `func`'s input type
712
713                let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
714
715                // FlatMap extends the existing columns
716                t_in.extend(t_out);
717                Ok(t_in)
718            }
719            Filter { input, predicates } => {
720                let mut t_in = tc.typecheck(input, ctx)?;
721
722                // Set as nonnull any columns where null values would cause
723                // any predicate to evaluate to null.
724                for column in non_nullable_columns(predicates) {
725                    t_in[column].nullable = false;
726                }
727
728                for scalar_expr in predicates {
729                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
730
731                    // filter condition must be boolean
732                    // ignoring nullability: null is treated as false
733                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
734                    if t.scalar_type != ReprScalarType::Bool {
735                        let sub = t.scalar_type.clone();
736
737                        return Err(TypeError::MismatchColumn {
738                            source: expr,
739                            got: t,
740                            expected: ReprColumnType {
741                                scalar_type: ReprScalarType::Bool,
742                                nullable: true,
743                            },
744                            diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
745                            message: "expected boolean condition".to_string(),
746                        });
747                    }
748
749                    if self.disallow_dummy && scalar_expr.contains_dummy() {
750                        return Err(TypeError::DisallowedDummy {
751                            source: expr,
752                        });
753                    }
754                }
755
756                Ok(t_in)
757            }
758            Join {
759                inputs,
760                equivalences,
761                implementation,
762            } => {
763                let mut t_in_global = Vec::new();
764                let mut t_in_local = vec![Vec::new(); inputs.len()];
765
766                for (i, input) in inputs.iter().enumerate() {
767                    let input_t = tc.typecheck(input, ctx)?;
768                    t_in_global.extend(input_t.clone());
769                    t_in_local[i] = input_t;
770                }
771
772                for eq_class in equivalences {
773                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
774
775                    let mut all_nullable = true;
776
777                    for scalar_expr in eq_class {
778                        // Note: the equivalences have global column references
779                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
780
781                        if !t_expr.nullable {
782                            all_nullable = false;
783                        }
784
785                        if let Some(t_first) = t_exprs.get(0) {
786                            let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
787                            if !diffs.is_empty() {
788                                return Err(TypeError::MismatchColumn {
789                                    source: expr,
790                                    got: t_expr,
791                                    expected: t_first.clone(),
792                                    diffs,
793                                    message: "equivalence class members have different scalar types".to_string(),
794                                });
795                            }
796
797                            // equivalences may or may not match on nullability
798                            // before JoinImplementation runs, nullability should match.
799                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
800                            if self.strict_join_equivalences {
801                                if t_expr.nullable != t_first.nullable {
802                                    let sub = t_expr.clone();
803                                    let sup = t_first.clone();
804
805                                    let err = TypeError::MismatchColumn {
806                                        source: expr,
807                                        got: t_expr.clone(),
808                                        expected: t_first.clone(),
809                                        diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
810                                        message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
811                                    };
812
813                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
814                                    ::tracing::debug!("{err}");
815                                }
816                            }
817                        }
818
819                        if self.disallow_dummy && scalar_expr.contains_dummy() {
820                            return Err(TypeError::DisallowedDummy {
821                                source: expr,
822                            });
823                        }
824
825                        t_exprs.push(t_expr);
826                    }
827
828                    if self.strict_join_equivalences && all_nullable {
829                        let err = TypeError::BadJoinEquivalence {
830                            source: expr,
831                            got: t_exprs,
832                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
833                        };
834
835                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
836                        ::tracing::debug!("{err}");
837                    }
838                }
839
840                // check that the join implementation is consistent
841                match implementation {
842                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
843                        if let Some(key) = first_key {
844                            for k in key {
845                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
846                            }
847                        }
848
849                        for (idx, key, _) in others {
850                            for k in key {
851                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
852                            }
853                        }
854                    }
855                    JoinImplementation::DeltaQuery(plans) => {
856                        for plan in plans {
857                            for (idx, key, _) in plan {
858                                for k in key {
859                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
860                                }
861                            }
862                        }
863                    }
864                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
865                        let typ: Vec<ReprColumnType> = key
866                            .iter()
867                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
868                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
869
870                        for row in consts {
871                            let datums = row.unpack();
872
873                            // correct length
874                            if datums.len() != typ.len() {
875                                return Err(TypeError::BadConstantRow {
876                                    source: expr,
877                                    got: row.clone(),
878                                    expected: typ,
879                                });
880                            }
881
882                            // correct types
883                            if datums
884                                .iter()
885                                .zip_eq(typ.iter())
886                                .any(|(d, ty)| d != &mz_repr::Datum::Dummy && !d.is_instance_of(ty))
887                            {
888                                return Err(TypeError::BadConstantRow {
889                                    source: expr,
890                                    got: row.clone(),
891                                    expected: typ,
892                                });
893                            }
894                        }
895                    }
896                    JoinImplementation::Unimplemented => (),
897                }
898
899                Ok(t_in_global)
900            }
901            Reduce {
902                input,
903                group_key,
904                aggregates,
905                monotonic: _,
906                expected_group_size: _,
907            } => {
908                let t_in = tc.typecheck(input, ctx)?;
909
910                let mut t_out = group_key
911                    .iter()
912                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
913                    .collect::<Result<Vec<_>, _>>()?;
914
915                    if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
916                        return Err(TypeError::DisallowedDummy {
917                            source: expr,
918                        });
919                    }
920
921                for agg in aggregates {
922                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
923                }
924
925                Ok(t_out)
926            }
927            TopK {
928                input,
929                group_key,
930                order_key,
931                limit: _,
932                offset: _,
933                monotonic: _,
934                expected_group_size: _,
935            } => {
936                let t_in = tc.typecheck(input, ctx)?;
937
938                for &k in group_key {
939                    if k >= t_in.len() {
940                        return Err(TypeError::BadTopKGroupKey {
941                            source: expr,
942                            k,
943                            input_type: t_in,
944                        });
945                    }
946                }
947
948                for order in order_key {
949                    if order.column >= t_in.len() {
950                        return Err(TypeError::BadTopKOrdering {
951                            source: expr,
952                            order: order.clone(),
953                            input_type: t_in,
954                        });
955                    }
956                }
957
958                Ok(t_in)
959            }
960            Negate { input } => tc.typecheck(input, ctx),
961            Threshold { input } => tc.typecheck(input, ctx),
962            Union { base, inputs } => {
963                let mut t_base = tc.typecheck(base, ctx)?;
964
965                for input in inputs {
966                    let t_input = tc.typecheck(input, ctx)?;
967
968                    let len_sub = t_base.len();
969                    let len_sup = t_input.len();
970                    if len_sub != len_sup {
971                        return Err(TypeError::MismatchColumns {
972                            source: expr,
973                            got: t_base.clone(),
974                            expected: t_input,
975                            diffs: vec![ReprRelationTypeDifference::Length {
976                                len_sub,
977                                len_sup,
978                            }],
979                            message: "Union branches have different numbers of columns".to_string(),
980                        });
981                    }
982
983                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
984                        let diffs = column_union(base_col, &input_col);
985                        if !diffs.is_empty() {
986                            return Err(TypeError::MismatchColumn {
987                                    source: expr,
988                                    got: input_col,
989                                    expected: base_col.clone(),
990                                    diffs,
991                                    message:
992                                        "couldn't compute union of column types in Union"
993                                    .to_string(),
994                            });
995                        }
996
997                    }
998                }
999
1000                Ok(t_base)
1001            }
1002            Let { id, value, body } => {
1003                let t_value = tc.typecheck(value, ctx)?;
1004
1005                let binding = Id::Local(*id);
1006                if ctx.contains_key(&binding) {
1007                    return Err(TypeError::Shadowing {
1008                        source: expr,
1009                        id: binding,
1010                    });
1011                }
1012
1013                let mut body_ctx = ctx.clone();
1014                body_ctx.insert(Id::Local(*id), t_value);
1015
1016                tc.typecheck(body, &body_ctx)
1017            }
1018            LetRec { ids, values, body, limits: _ } => {
1019                if ids.len() != values.len() {
1020                    return Err(TypeError::BadLetRecBindings { source: expr });
1021                }
1022
1023                // temporary hack: steal info from the Gets inside to learn the expected types
1024                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1025                let mut ctx = ctx.clone();
1026                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1027                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1028                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1029                }
1030
1031                for (id, value) in ids.iter().zip_eq(values.iter()) {
1032                    let typ = tc.typecheck(value, &ctx)?;
1033
1034                    let id = Id::Local(id.clone());
1035                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1036                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1037                            // we expect an EXACT match, but don't care about nullability
1038                            let diffs = column_union(base_col, &input_col);
1039                            if !diffs.is_empty() {
1040                                 return Err(TypeError::MismatchColumn {
1041                                        source: expr,
1042                                        got: input_col,
1043                                        expected: base_col.clone(),
1044                                        diffs,
1045                                        message:
1046                                            "couldn't compute union of column types in LetRec"
1047                                        .to_string(),
1048                                    })
1049                            }
1050                        }
1051                    } else {
1052                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1053                        ctx.insert(id, typ);
1054                    }
1055                }
1056
1057                tc.typecheck(body, &ctx)
1058            }
1059            ArrangeBy { input, keys } => {
1060                let t_in = tc.typecheck(input, ctx)?;
1061
1062                for key in keys {
1063                    for k in key {
1064                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1065                    }
1066                }
1067
1068                Ok(t_in)
1069            }
1070        })
1071    }
1072
1073    /// Traverses a term to collect the types of given ids.
1074    ///
1075    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1076    fn collect_recursive_variable_types<'a>(
1077        &self,
1078        expr: &'a MirRelationExpr,
1079        ids: &[LocalId],
1080        ctx: &mut Context,
1081    ) -> Result<(), TypeError<'a>> {
1082        use MirRelationExpr::*;
1083
1084        self.checked_recur(|tc| {
1085            match expr {
1086                Get {
1087                    id: Id::Local(id),
1088                    typ,
1089                    ..
1090                } => {
1091                    if !ids.contains(id) {
1092                        return Ok(());
1093                    }
1094
1095                    let id = Id::Local(id.clone());
1096                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1097                        let typ = typ
1098                            .column_types
1099                            .iter()
1100                            .map(ReprColumnType::from)
1101                            .collect_vec();
1102
1103                        if ctx_typ.len() != typ.len() {
1104                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1105
1106                            return Err(TypeError::MismatchColumns {
1107                                source: expr,
1108                                got: typ,
1109                                expected: ctx_typ.clone(),
1110                                diffs,
1111                                message: "environment and type annotation did not match"
1112                                    .to_string(),
1113                            });
1114                        }
1115
1116                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1117                            let diffs = column_union(base_col, &input_col);
1118                            if !diffs.is_empty() {
1119                                return Err(TypeError::MismatchColumn {
1120                                    source: expr,
1121                                    got: input_col,
1122                                    expected: base_col.clone(),
1123                                    diffs,
1124                                    message:
1125                                        "couldn't compute union of column types in Get and context"
1126                                            .to_string(),
1127                                });
1128                            }
1129                        }
1130                    } else {
1131                        ctx.insert(
1132                            id,
1133                            typ.column_types
1134                                .iter()
1135                                .map(ReprColumnType::from)
1136                                .collect_vec(),
1137                        );
1138                    }
1139                }
1140                Get {
1141                    id: Id::Global(..), ..
1142                }
1143                | Constant { .. } => (),
1144                Let { id, value, body } => {
1145                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1146
1147                    // we've shadowed the id
1148                    if ids.contains(id) {
1149                        return Err(TypeError::Shadowing {
1150                            source: expr,
1151                            id: Id::Local(*id),
1152                        });
1153                    }
1154
1155                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1156                }
1157                LetRec {
1158                    ids: inner_ids,
1159                    values,
1160                    body,
1161                    limits: _,
1162                } => {
1163                    for inner_id in inner_ids {
1164                        if ids.contains(inner_id) {
1165                            return Err(TypeError::Shadowing {
1166                                source: expr,
1167                                id: Id::Local(*inner_id),
1168                            });
1169                        }
1170                    }
1171
1172                    for value in values {
1173                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1174                    }
1175
1176                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1177                }
1178                Project { input, .. }
1179                | Map { input, .. }
1180                | FlatMap { input, .. }
1181                | Filter { input, .. }
1182                | Reduce { input, .. }
1183                | TopK { input, .. }
1184                | Negate { input }
1185                | Threshold { input }
1186                | ArrangeBy { input, .. } => {
1187                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1188                }
1189                Join { inputs, .. } => {
1190                    for input in inputs {
1191                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1192                    }
1193                }
1194                Union { base, inputs } => {
1195                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1196
1197                    for input in inputs {
1198                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1199                    }
1200                }
1201            }
1202
1203            Ok(())
1204        })
1205    }
1206
1207    fn typecheck_scalar<'a>(
1208        &self,
1209        expr: &'a MirScalarExpr,
1210        source: &'a MirRelationExpr,
1211        column_types: &[ReprColumnType],
1212    ) -> Result<ReprColumnType, TypeError<'a>> {
1213        use MirScalarExpr::*;
1214
1215        self.checked_recur(|tc| match expr {
1216            Column(i, _) => match column_types.get(*i) {
1217                Some(ty) => Ok(ty.clone()),
1218                None => Err(TypeError::NoSuchColumn {
1219                    source,
1220                    expr,
1221                    col: *i,
1222                }),
1223            },
1224            Literal(row, typ) => {
1225                let typ = ReprColumnType::from(typ);
1226                if let Ok(row) = row {
1227                    let datums = row.unpack();
1228
1229                    if datums.len() != 1
1230                        || (datums[0] != mz_repr::Datum::Dummy && !datums[0].is_instance_of(&typ))
1231                    {
1232                        return Err(TypeError::BadConstantRow {
1233                            source,
1234                            got: row.clone(),
1235                            expected: vec![typ],
1236                        });
1237                    }
1238                }
1239
1240                Ok(typ)
1241            }
1242            CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1243            CallUnary { expr, func } => {
1244                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1245                let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1246                Ok(ReprColumnType::from(&typ_out))
1247            }
1248            CallBinary { expr1, expr2, func } => {
1249                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1250                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1251                let typ_out = func.output_type(
1252                    SqlColumnType::from_repr(&typ_in1),
1253                    SqlColumnType::from_repr(&typ_in2),
1254                );
1255                Ok(ReprColumnType::from(&typ_out))
1256            }
1257            CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1258                &func.output_type(
1259                    exprs
1260                        .iter()
1261                        .map(|e| {
1262                            tc.typecheck_scalar(e, source, column_types)
1263                                .map(|typ| SqlColumnType::from_repr(&typ))
1264                        })
1265                        .collect::<Result<Vec<_>, TypeError>>()?,
1266                ),
1267            )),
1268            If { cond, then, els } => {
1269                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1270
1271                // condition must be boolean
1272                // ignoring nullability: null is treated as false
1273                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1274                if cond_type.scalar_type != ReprScalarType::Bool {
1275                    let sub = cond_type.scalar_type.clone();
1276
1277                    return Err(TypeError::MismatchColumn {
1278                        source,
1279                        got: cond_type,
1280                        expected: ReprColumnType {
1281                            scalar_type: ReprScalarType::Bool,
1282                            nullable: true,
1283                        },
1284                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1285                            sub,
1286                            sup: ReprScalarType::Bool,
1287                        }],
1288                        message: "expected boolean condition".to_string(),
1289                    });
1290                }
1291
1292                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1293                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1294
1295                let diffs = column_union(&mut then_type, &else_type);
1296                if !diffs.is_empty() {
1297                    return Err(TypeError::MismatchColumn {
1298                        source,
1299                        got: then_type,
1300                        expected: else_type,
1301                        diffs,
1302                        message: "couldn't compute union of column types for If".to_string(),
1303                    });
1304                }
1305
1306                Ok(then_type)
1307            }
1308        })
1309    }
1310
1311    /// Typecheck an `AggregateExpr`
1312    pub fn typecheck_aggregate<'a>(
1313        &self,
1314        expr: &'a AggregateExpr,
1315        source: &'a MirRelationExpr,
1316        column_types: &[ReprColumnType],
1317    ) -> Result<ReprColumnType, TypeError<'a>> {
1318        self.checked_recur(|tc| {
1319            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1320
1321            // TODO check that t_in is actually acceptable for `func`
1322
1323            Ok(ReprColumnType::from(
1324                &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1325            ))
1326        })
1327    }
1328}
1329
/// Reports a type error through a channel chosen by severity
///
/// `type_error!(severe, fmt, args...)` forwards the format arguments to `soft_panic_or_log!` when `severe` is `true`,
/// and to `::tracing::debug!` otherwise.
macro_rules! type_error {
    ($severe:expr, $($fmt_args:tt)+) => {{
        if !$severe {
          ::tracing::debug!($($fmt_args)+);
        } else {
          soft_panic_or_log!($($fmt_args)+);
        }
    }}
}
1342
1343impl crate::Transform for Typecheck {
1344    fn name(&self) -> &'static str {
1345        "Typecheck"
1346    }
1347
1348    fn actually_perform_transform(
1349        &self,
1350        relation: &mut MirRelationExpr,
1351        transform_ctx: &mut crate::TransformCtx,
1352    ) -> Result<(), crate::TransformError> {
1353        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1354
1355        let expected = transform_ctx
1356            .global_id
1357            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1358
1359        if let Some(id) = transform_ctx.global_id {
1360            if self.disallow_new_globals
1361                && expected.is_none()
1362                && transform_ctx.global_id.is_some()
1363                && !id.is_transient()
1364            {
1365                type_error!(
1366                    false, // not severe
1367                    "type warning: new non-transient global id {id}\n{}",
1368                    relation.pretty()
1369                );
1370            }
1371        }
1372
1373        let got = self.typecheck(relation, &typecheck_ctx);
1374
1375        let humanizer = mz_repr::explain::DummyHumanizer;
1376
1377        match (got, expected) {
1378            (Ok(got), Some(expected)) => {
1379                let id = transform_ctx.global_id.unwrap();
1380
1381                // contravariant: global types can be updated
1382                let diffs = relation_subtype_difference(expected, &got);
1383                if !diffs.is_empty() {
1384                    // SEVERE only if got and expected have true differences, not just nullability
1385                    let severity = diffs
1386                        .iter()
1387                        .any(|diff| diff.clone().ignore_nullability().is_some());
1388
1389                    let err = TypeError::MismatchColumns {
1390                        source: relation,
1391                        got,
1392                        expected: expected.clone(),
1393                        diffs,
1394                        message: format!(
1395                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1396                        ),
1397                    };
1398
1399                    type_error!(severity, "type error in known global id {id}:\n{err}");
1400                }
1401            }
1402            (Ok(got), None) => {
1403                if let Some(id) = transform_ctx.global_id {
1404                    typecheck_ctx.insert(Id::Global(id), got);
1405                }
1406            }
1407            (Err(err), _) => {
1408                let (expected, binding) = match expected {
1409                    Some(expected) => {
1410                        let id = transform_ctx.global_id.unwrap();
1411                        (
1412                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1413                            format!("known global id {id}"),
1414                        )
1415                    }
1416                    None => ("".to_string(), "transient query".to_string()),
1417                };
1418
1419                type_error!(
1420                    true, // SEVERE: the transformed code is inconsistent
1421                    "type error in {binding}:\n{err}\n{expected}{}",
1422                    relation.pretty()
1423                );
1424            }
1425        }
1426
1427        Ok(())
1428    }
1429}
1430
1431/// Prints a type prettily with a given `ExprHumanizer`
1432pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1433where
1434    H: ExprHumanizer,
1435{
1436    let mut s = String::with_capacity(2 + 3 * cols.len());
1437
1438    s.push('(');
1439
1440    let mut it = cols.iter().peekable();
1441    while let Some(col) = it.next() {
1442        s.push_str(&humanizer.humanize_column_type_repr(col, false));
1443
1444        if it.peek().is_some() {
1445            s.push_str(", ");
1446        }
1447    }
1448
1449    s.push(')');
1450
1451    s
1452}
1453
1454impl ReprRelationTypeDifference {
1455    /// Pretty prints a type difference
1456    ///
1457    /// Always indents two spaces
1458    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1459    where
1460        H: ExprHumanizer,
1461    {
1462        use ReprRelationTypeDifference::*;
1463        match self {
1464            Length { len_sub, len_sup } => {
1465                writeln!(
1466                    f,
1467                    "  number of columns do not match ({len_sub} != {len_sup})"
1468                )
1469            }
1470            Column { col, diff } => {
1471                writeln!(f, "  column {col} differs:")?;
1472                diff.humanize(4, h, f)
1473            }
1474        }
1475    }
1476}
1477
1478impl ReprColumnTypeDifference {
1479    /// Pretty prints a type difference at a given indentation level
1480    pub fn humanize<H>(
1481        &self,
1482        indent: usize,
1483        h: &H,
1484        f: &mut std::fmt::Formatter<'_>,
1485    ) -> std::fmt::Result
1486    where
1487        H: ExprHumanizer,
1488    {
1489        use ReprColumnTypeDifference::*;
1490
1491        // indent
1492        write!(f, "{:indent$}", "")?;
1493
1494        match self {
1495            NotSubtype { sub, sup } => {
1496                let sub = h.humanize_scalar_type_repr(sub, false);
1497                let sup = h.humanize_scalar_type_repr(sup, false);
1498
1499                writeln!(f, "{sub} is a not a subtype of {sup}")
1500            }
1501            Nullability { sub, sup } => {
1502                let sub = h.humanize_column_type_repr(sub, false);
1503                let sup = h.humanize_column_type_repr(sup, false);
1504
1505                writeln!(f, "{sub} is nullable but {sup} is not")
1506            }
1507            ElementType { ctor, element_type } => {
1508                writeln!(f, "{ctor} element types differ:")?;
1509
1510                element_type.humanize(indent + 2, h, f)
1511            }
1512            RecordMissingFields { missing } => {
1513                write!(f, "missing column fields:")?;
1514                for col in missing {
1515                    write!(f, " {col}")?;
1516                }
1517                f.write_char('\n')
1518            }
1519            RecordFields { fields } => {
1520                writeln!(f, "{} record fields differ:", fields.len())?;
1521
1522                for (i, diff) in fields.iter().enumerate() {
1523                    writeln!(f, "{:indent$}  field {i}:", "")?;
1524                    diff.humanize(indent + 4, h, f)?;
1525                }
1526                Ok(())
1527            }
1528        }
1529    }
1530}
1531
1532/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
1533#[allow(missing_debug_implementations)]
1534pub struct TypeErrorHumanizer<'a, 'b, H>
1535where
1536    H: ExprHumanizer,
1537{
1538    err: &'a TypeError<'a>,
1539    humanizer: &'b H,
1540}
1541
1542impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1543where
1544    H: ExprHumanizer,
1545{
1546    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1547    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1548        Self { err, humanizer }
1549    }
1550}
1551
1552impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1553where
1554    H: ExprHumanizer,
1555{
1556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1557        self.err.humanize(self.humanizer, f)
1558    }
1559}
1560
1561impl<'a> std::fmt::Display for TypeError<'a> {
1562    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1563        TypeErrorHumanizer {
1564            err: self,
1565            humanizer: &DummyHumanizer,
1566        }
1567        .fmt(f)
1568    }
1569}
1570
1571impl<'a> TypeError<'a> {
1572    /// The source of the type error
1573    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1574        use TypeError::*;
1575        match self {
1576            Unbound { source, .. }
1577            | NoSuchColumn { source, .. }
1578            | MismatchColumn { source, .. }
1579            | MismatchColumns { source, .. }
1580            | BadConstantRow { source, .. }
1581            | BadProject { source, .. }
1582            | BadJoinEquivalence { source, .. }
1583            | BadTopKGroupKey { source, .. }
1584            | BadTopKOrdering { source, .. }
1585            | BadLetRecBindings { source }
1586            | Shadowing { source, .. }
1587            | DisallowedDummy { source, .. } => Some(source),
1588            Recursion { .. } => None,
1589        }
1590    }
1591
1592    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1593    where
1594        H: ExprHumanizer,
1595    {
1596        if let Some(source) = self.source() {
1597            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1598        }
1599
1600        use TypeError::*;
1601        match self {
1602            Unbound { source: _, id, typ } => {
1603                let typ = columns_pretty(&typ.column_types, humanizer);
1604                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1605            }
1606            NoSuchColumn {
1607                source: _,
1608                expr,
1609                col,
1610            } => writeln!(f, "{expr} references non-existent column {col}")?,
1611            MismatchColumn {
1612                source: _,
1613                got,
1614                expected,
1615                diffs,
1616                message,
1617            } => {
1618                let got = humanizer.humanize_column_type_repr(got, false);
1619                let expected = humanizer.humanize_column_type_repr(expected, false);
1620                writeln!(
1621                    f,
1622                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1623                )?;
1624
1625                for diff in diffs {
1626                    diff.humanize(2, humanizer, f)?;
1627                }
1628            }
1629            MismatchColumns {
1630                source: _,
1631                got,
1632                expected,
1633                diffs,
1634                message,
1635            } => {
1636                let got = columns_pretty(got, humanizer);
1637                let expected = columns_pretty(expected, humanizer);
1638
1639                writeln!(
1640                    f,
1641                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1642                )?;
1643
1644                for diff in diffs {
1645                    diff.humanize(humanizer, f)?;
1646                }
1647            }
1648            BadConstantRow {
1649                source: _,
1650                got,
1651                expected,
1652            } => {
1653                let expected = columns_pretty(expected, humanizer);
1654
1655                writeln!(
1656                    f,
1657                    "bad constant row\n      got {got}\nexpected row of type {expected}"
1658                )?
1659            }
1660            BadProject {
1661                source: _,
1662                got,
1663                input_type,
1664            } => {
1665                let input_type = columns_pretty(input_type, humanizer);
1666
1667                writeln!(
1668                    f,
1669                    "projection of non-existant columns {got:?} from type {input_type}"
1670                )?
1671            }
1672            BadJoinEquivalence {
1673                source: _,
1674                got,
1675                message,
1676            } => {
1677                let got = columns_pretty(got, humanizer);
1678
1679                writeln!(f, "bad join equivalence {got}: {message}")?
1680            }
1681            BadTopKGroupKey {
1682                source: _,
1683                k,
1684                input_type,
1685            } => {
1686                let input_type = columns_pretty(input_type, humanizer);
1687
1688                writeln!(
1689                    f,
1690                    "TopK group key component references invalid column {k} in columns: {input_type}"
1691                )?
1692            }
1693            BadTopKOrdering {
1694                source: _,
1695                order,
1696                input_type,
1697            } => {
1698                let col = order.column;
1699                let num_cols = input_type.len();
1700                let are = if num_cols == 1 { "is" } else { "are" };
1701                let s = if num_cols == 1 { "" } else { "s" };
1702                let input_type = columns_pretty(input_type, humanizer);
1703
1704                // TODO(cloud#8196)
1705                let mode = HumanizedExplain::new(false);
1706                let order = mode.expr(order, None);
1707
1708                writeln!(
1709                    f,
1710                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
1711                )?
1712            }
1713            BadLetRecBindings { source: _ } => {
1714                writeln!(f, "LetRec ids and definitions don't line up")?
1715            }
1716            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
1717            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
1718            Recursion { error } => writeln!(f, "{error}")?,
1719        }
1720
1721        Ok(())
1722    }
1723}