Skip to main content

mz_transform/
reprtypecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27    ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28    SqlColumnType,
29};
30
/// Typechecking contexts as shared by various typechecking passes.
///
/// The context is wrapped in an `Arc<Mutex<_>>` so that one context can be shared by multiple
/// typechecker passes. Shared contexts help catch consistency issues.
pub type SharedContext = Arc<Mutex<Context>>;
36
37/// Generates an empty context
38pub fn empty_context() -> SharedContext {
39    Arc::new(Mutex::new(BTreeMap::new()))
40}
41
/// The possible forms of inconsistency/errors discovered during typechecking.
///
/// Every variant except `Recursion` has a `source` field identifying the MIR term that is home
/// to the error (though not necessarily the root cause of the error).
#[derive(Clone, Debug)]
pub enum TypeError<'a> {
    /// Unbound identifiers (local or global)
    Unbound {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The (unbound) identifier referenced
        id: Id,
        /// The type `id` was expected to have
        typ: ReprRelationType,
    },
    /// Dereference of a non-existent column
    NoSuchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// Scalar expression that references an invalid column
        expr: &'a MirScalarExpr,
        /// The invalid column referenced
        col: usize,
    },
    /// A single column type does not match
    MismatchColumn {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column type we found (`sub` type)
        got: ReprColumnType,
        /// The column type we expected (`sup` type)
        expected: ReprColumnType,
        /// The differences between these types
        diffs: Vec<ReprColumnTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// Relation column types do not match
    MismatchColumns {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The column types we found (`sub` type)
        got: Vec<ReprColumnType>,
        /// The column types we expected (`sup` type)
        expected: Vec<ReprColumnType>,
        /// The differences between these types
        diffs: Vec<ReprRelationTypeDifference>,
        /// An explanatory message
        message: String,
    },
    /// A constant row does not have the correct number of columns
    BadConstantRowLen {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The number of datums in the constant row
        got: usize,
        /// The expected type (whose number of columns that row does not have)
        expected: Vec<ReprColumnType>,
    },
    /// A constant row does not have the correct type
    BadConstantRow {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The columns that mismatched: each entry pairs the column index with a description
        /// of how the datum we got differs from the expected type
        mismatches: Vec<(usize, DatumTypeDifference)>,
        /// The expected type (which that row does not have)
        expected: Vec<ReprColumnType>,
        // TODO(mgree) with a good way to get the type of a Datum, we could give a detailed diff here
        // this is difficult for empty structures where we may not know the element type
    },
    /// Projection of a non-existent column
    BadProject {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The columns projected
        got: Vec<usize>,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// An equivalence class in a join was malformed
    BadJoinEquivalence {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The join equivalences
        got: Vec<ReprColumnType>,
        /// The problem with the join equivalences
        message: String,
    },
    /// TopK grouping by non-existent column
    BadTopKGroupKey {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The bad column reference in the group key
        k: usize,
        /// The input columns (which don't have that column)
        input_type: Vec<ReprColumnType>,
    },
    /// TopK ordering by non-existent column
    BadTopKOrdering {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The ordering used
        order: ColumnOrder,
        /// The input columns (which don't work for that ordering)
        input_type: Vec<ReprColumnType>,
    },
    /// LetRec bindings are malformed
    BadLetRecBindings {
        /// Expression with the bug
        source: &'a MirRelationExpr,
    },
    /// Local identifiers are shadowed
    Shadowing {
        /// Expression with the bug
        source: &'a MirRelationExpr,
        /// The id that was shadowed
        id: Id,
    },
    /// Recursion depth exceeded
    ///
    /// Unlike the other variants, this one carries no `source` expression.
    Recursion {
        /// The error that aborted recursion
        error: RecursionLimitError,
    },
    /// A dummy value was found
    DisallowedDummy {
        /// The expression with the dummy value
        source: &'a MirRelationExpr,
    },
}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173    fn from(error: RecursionLimitError) -> Self {
174        TypeError::Recursion { error }
175    }
176}
177
/// Maps each identifier to the column types recorded for it
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
/// Characterizes differences between relation types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`.
/// Produced by `relation_subtype_difference`, which reports a `Length` difference on its own
/// (without any per-column differences) when the column counts disagree.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// `sub` and `sup` don't have the same number of columns
    Length {
        /// Length of `sub`
        len_sub: usize,
        /// Length of `sup`
        len_sup: usize,
    },
    /// `sub` and `sup` differ at the indicated column
    Column {
        /// The (zero-based) position of the column at which `sub` and `sup` differ
        col: usize,
        /// The difference between `sub` and `sup`
        diff: ReprColumnTypeDifference,
    },
}
200
/// Characterizes differences between individual column types
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types
#[derive(Clone, Debug, Hash)]
pub enum ReprColumnTypeDifference {
    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
    ///
    /// Also reported when `sub` and `sup` are both records with a different number of fields.
    NotSubtype {
        /// Would-be subtype
        sub: ReprScalarType,
        /// Would-be supertype
        sup: ReprScalarType,
    },
    /// `sub` was nullable but `sup` was not
    Nullability {
        /// Would-be subtype
        sub: ReprColumnType,
        /// Would-be supertype
        sup: ReprColumnType,
    },
    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`s
    ElementType {
        /// The type constructor (list, array, etc.)
        ctor: String,
        /// The difference in the element type
        element_type: Box<ReprColumnTypeDifference>,
    },
    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
    RecordMissingFields {
        /// The missing fields
        missing: Vec<ColumnName>,
    },
    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
    RecordFields {
        /// The differences, by field
        fields: Vec<ReprColumnTypeDifference>,
    },
}
239
240impl ReprRelationTypeDifference {
241    /// Returns the same type difference, but ignoring nullability
242    ///
243    /// Returns `None` when _all_ of the differences are due to nullability
244    pub fn ignore_nullability(self) -> Option<Self> {
245        use ReprRelationTypeDifference::*;
246
247        match self {
248            Length { .. } => Some(self),
249            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250        }
251    }
252}
253
254impl ReprColumnTypeDifference {
255    /// Returns the same type difference, but ignoring nullability
256    ///
257    /// Returns `None` when _all_ of the differences are due to nullability
258    pub fn ignore_nullability(self) -> Option<Self> {
259        use ReprColumnTypeDifference::*;
260
261        match self {
262            Nullability { .. } => None,
263            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264            ElementType { ctor, element_type } => {
265                element_type
266                    .ignore_nullability()
267                    .map(|element_type| ElementType {
268                        ctor,
269                        element_type: Box::new(element_type),
270                    })
271            }
272            RecordFields { fields } => {
273                let fields = fields
274                    .into_iter()
275                    .flat_map(|diff| diff.ignore_nullability())
276                    .collect::<Vec<_>>();
277
278                if fields.is_empty() {
279                    None
280                } else {
281                    Some(RecordFields { fields })
282                }
283            }
284        }
285    }
286}
287
288/// Returns a list of differences that make `sub` not a subtype of `sup`
289///
290/// This function returns an empty list when `sub` is a subtype of `sup`
291pub fn relation_subtype_difference(
292    sub: &[ReprColumnType],
293    sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295    let mut diffs = Vec::new();
296
297    if sub.len() != sup.len() {
298        diffs.push(ReprRelationTypeDifference::Length {
299            len_sub: sub.len(),
300            len_sup: sup.len(),
301        });
302
303        // TODO(mgree) we could do an edit-distance computation to report more errors
304        return diffs;
305    }
306
307    diffs.extend(
308        sub.iter()
309            .zip_eq(sup.iter())
310            .enumerate()
311            .flat_map(|(col, (sub_ty, sup_ty))| {
312                column_subtype_difference(sub_ty, sup_ty)
313                    .into_iter()
314                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315            }),
316    );
317
318    diffs
319}
320
321/// Returns a list of differences that make `sub` not a subtype of `sup`
322///
323/// This function returns an empty list when `sub` is a subtype of `sup`
324pub fn column_subtype_difference(
325    sub: &ReprColumnType,
326    sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330    if sub.nullable && !sup.nullable {
331        diffs.push(ReprColumnTypeDifference::Nullability {
332            sub: sub.clone(),
333            sup: sup.clone(),
334        });
335    }
336
337    diffs
338}
339
340/// Returns a list of differences that make `sub` not a subtype of `sup`
341///
342/// This function returns an empty list when `sub` is a subtype of `sup`
343pub fn scalar_subtype_difference(
344    sub: &ReprScalarType,
345    sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347    use ReprScalarType::*;
348
349    let mut diffs = Vec::new();
350
351    match (sub, sup) {
352        (
353            List {
354                element_type: sub_elt,
355                ..
356            },
357            List {
358                element_type: sup_elt,
359                ..
360            },
361        )
362        | (
363            Map {
364                value_type: sub_elt,
365                ..
366            },
367            Map {
368                value_type: sup_elt,
369                ..
370            },
371        )
372        | (
373            Range {
374                element_type: sub_elt,
375                ..
376            },
377            Range {
378                element_type: sup_elt,
379                ..
380            },
381        )
382        | (Array(sub_elt), Array(sup_elt)) => {
383            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384            diffs.extend(
385                scalar_subtype_difference(sub_elt, sup_elt)
386                    .into_iter()
387                    .map(|diff| ReprColumnTypeDifference::ElementType {
388                        ctor: ctor.clone(),
389                        element_type: Box::new(diff),
390                    }),
391            );
392        }
393        (
394            Record {
395                fields: sub_fields, ..
396            },
397            Record {
398                fields: sup_fields, ..
399            },
400        ) => {
401            if sub_fields.len() != sup_fields.len() {
402                diffs.push(ReprColumnTypeDifference::NotSubtype {
403                    sub: sub.clone(),
404                    sup: sup.clone(),
405                });
406                return diffs;
407            }
408
409            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411            }
412        }
413        (_, _) => {
414            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415                diffs.push(ReprColumnTypeDifference::NotSubtype {
416                    sub: sub.clone(),
417                    sup: sup.clone(),
418                })
419            }
420        }
421    };
422
423    diffs
424}
425
426/// Unions `other` into `typ`, returning a list of differences on failure
427///
428/// This function returns an empty list when `typ` and `other` are a union
429pub fn scalar_union(
430    typ: &mut ReprScalarType,
431    other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433    use ReprScalarType::*;
434
435    let mut diffs = Vec::new();
436
437    // precomputing to appease the borrow checker
438    let ctor = ReprScalarBaseType::from(&*typ);
439    match (typ, other) {
440        (
441            List {
442                element_type: typ_elt,
443            },
444            List {
445                element_type: other_elt,
446            },
447        )
448        | (
449            Map {
450                value_type: typ_elt,
451            },
452            Map {
453                value_type: other_elt,
454            },
455        )
456        | (
457            Range {
458                element_type: typ_elt,
459            },
460            Range {
461                element_type: other_elt,
462            },
463        )
464        | (Array(typ_elt), Array(other_elt)) => {
465            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466            diffs.extend(
467                res.into_iter()
468                    .map(|diff| ReprColumnTypeDifference::ElementType {
469                        ctor: format!("{ctor:?}"),
470                        element_type: Box::new(diff),
471                    }),
472            );
473        }
474        (
475            Record { fields: typ_fields },
476            Record {
477                fields: other_fields,
478            },
479        ) => {
480            if typ_fields.len() != other_fields.len() {
481                diffs.push(ReprColumnTypeDifference::NotSubtype {
482                    sub: ReprScalarType::Record {
483                        fields: typ_fields.clone(),
484                    },
485                    sup: other.clone(),
486                });
487                return diffs;
488            }
489
490            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491                diffs.extend(column_union(typ_ty, other_ty));
492            }
493        }
494        (typ, _) => {
495            if ctor != ReprScalarBaseType::from(other) {
496                diffs.push(ReprColumnTypeDifference::NotSubtype {
497                    sub: typ.clone(),
498                    sup: other.clone(),
499                })
500            }
501        }
502    };
503
504    diffs
505}
506
507/// Unions `other` into `typ`, returning a list of differences on failure
508///
509/// This function returns an empty list when `typ` and `other` are a union
510pub fn column_union(
511    typ: &mut ReprColumnType,
512    other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516    if diffs.is_empty() {
517        typ.nullable |= other.nullable;
518    }
519
520    diffs
521}
522
523/// Returns true when it is safe to treat a `sub` row as an `sup` row
524///
525/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
526/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
527pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528    if sub.len() != sup.len() {
529        return false;
530    }
531
532    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534    })
535}
536
/// Characterizes how a Datum differs from a ReprColumnType
///
/// Differences found inside a container are wrapped in `ElementType`, one layer per level of nesting.
#[derive(Clone, Debug)]
pub enum DatumTypeDifference {
    /// Datum was null, but expected a non-null value (i.e., a ReprScalarType)
    Null {
        /// The representation scalar type that we expected
        expected: ReprScalarType,
    },
    /// Datum's kind doesn't match the expected ReprScalarType
    Mismatch {
        /// The debug rendering of the Datum that we got (we don't store the Datum itself to avoid borrowing issues)
        /// This debug rendering includes the DatumKind.
        got_debug: String,
        /// The representation scalar type that we expected
        expected: ReprScalarType,
    },
    /// Datum's dimensions didn't match the type's dimensions
    ///
    /// Also used for records, where the "dimension" is the number of fields.
    MismatchDimensions {
        /// The type constructor (list, array, int2vector, etc.)
        ctor: String,
        /// The dimension of the datum that we got
        got: usize,
        /// The dimension of the type that we expected
        expected: usize,
    },
    /// There was a nested difference in the element type of some container type
    ElementType {
        /// The type constructor (list, array, int2vector, etc.)
        ctor: String,
        /// The difference in the element type
        element_type: Box<DatumTypeDifference>,
    },
}
570
/// Computes the difference between this datum and the specified (representation) column type.
///
/// Returns `Ok(())` when no difference is found, and otherwise the first
/// [`DatumTypeDifference`] encountered. A `Datum::Null` is accepted outright when
/// `column_type` is nullable, and `Datum::Dummy` is accepted for every scalar type.
///
/// See [`Datum<'a>::is_instance_of`] for just getting a yes/no answer.
///
/// See [`Datum<'a>::is_instance_of_sql`] for comparing `Datum`s to `SqlColumnType`s.
fn datum_difference_with_column_type(
    datum: &Datum<'_>,
    column_type: &ReprColumnType,
) -> Result<(), DatumTypeDifference> {
    // Compares a datum against a scalar type; nullability of the column itself is handled by
    // the caller at the bottom of the enclosing function.
    fn difference_with_scalar_type(
        datum: &Datum<'_>,
        scalar_type: &ReprScalarType,
    ) -> Result<(), DatumTypeDifference> {
        // Builds the generic "wrong kind of datum" error.
        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
            Err(DatumTypeDifference::Mismatch {
                // this will be redacted as appropriate
                got_debug: format!("{got:?}"),
                expected: expected.clone(),
            })
        }

        if let ReprScalarType::Jsonb = scalar_type {
            // json type checking
            match datum {
                Datum::Dummy => Ok(()), // allow dummys (unlike is_instance_of)
                Datum::Null => Err(DatumTypeDifference::Null {
                    expected: ReprScalarType::Jsonb,
                }),
                Datum::JsonNull
                | Datum::False
                | Datum::True
                | Datum::Numeric(_)
                | Datum::String(_) => Ok(()),
                // Lists and maps are checked recursively, with every element/value itself
                // held to the Jsonb type.
                Datum::List(list) => {
                    for elem in list.iter() {
                        difference_with_scalar_type(&elem, scalar_type)?;
                    }
                    Ok(())
                }
                Datum::Map(dict) => {
                    for (_, val) in dict.iter() {
                        difference_with_scalar_type(&val, scalar_type)?;
                    }
                    Ok(())
                }
                _ => mismatch(datum, scalar_type),
            }
        } else {
            // Wraps a nested difference with the name of the container it was found in.
            fn element_type_difference(
                ctor: &str,
                element_type: DatumTypeDifference,
            ) -> DatumTypeDifference {
                DatumTypeDifference::ElementType {
                    ctor: ctor.to_string(),
                    element_type: Box::new(element_type),
                }
            }
            // Arm order matters: for each datum kind, the accepting arm(s) come first and the
            // catch-all `(kind, _)` arm that follows reports a mismatch.
            match (datum, scalar_type) {
                (Datum::Dummy, _) => Ok(()), // allow dummys (unlike is_instance_of)
                (Datum::Null, _) => Err(DatumTypeDifference::Null {
                    expected: scalar_type.clone(),
                }),
                (Datum::False, ReprScalarType::Bool) => Ok(()),
                (Datum::False, _) => mismatch(datum, scalar_type),
                (Datum::True, ReprScalarType::Bool) => Ok(()),
                (Datum::True, _) => mismatch(datum, scalar_type),
                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
                (Datum::Date(_), _) => mismatch(datum, scalar_type),
                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
                (Datum::Time(_), _) => mismatch(datum, scalar_type),
                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
                (Datum::String(_), ReprScalarType::String) => Ok(()),
                (Datum::String(_), _) => mismatch(datum, scalar_type),
                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
                (Datum::Array(array), ReprScalarType::Array(t)) => {
                    for e in array.elements().iter() {
                        // null elements are accepted regardless of the element type
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, t)
                            .map_err(|e| element_type_difference("array", e))?;
                    }
                    Ok(())
                }
                (Datum::Array(array), ReprScalarType::Int2Vector) => {
                    if !array.has_int2vector_dims() {
                        // Int2Vector is 1-D, but empty arrays use 0 dimensions (PostgreSQL convention);
                        // the error message just mentions 1 dimension, but 0 dimensions are allowed when empty.
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "int2vector".to_string(),
                            got: array.dims().len(),
                            expected: 1,
                        });
                    }

                    // Unlike the plain array case, null elements are not skipped here: every
                    // element is checked against Int16.
                    for e in array.elements().iter() {
                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
                            .map_err(|e| element_type_difference("int2vector", e))?;
                    }

                    Ok(())
                }
                (Datum::Array(_), _) => mismatch(datum, scalar_type),
                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
                    for e in list.iter() {
                        // null elements are accepted regardless of the element type
                        if let Datum::Null = e {
                            continue;
                        }

                        difference_with_scalar_type(&e, element_type)
                            .map_err(|e| element_type_difference("list", e))?;
                    }
                    Ok(())
                }
                // A list datum is also checked against a record type: one element per field.
                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
                    let len = list.iter().count();
                    if len != fields.len() {
                        return Err(DatumTypeDifference::MismatchDimensions {
                            ctor: "record".to_string(),
                            got: len,
                            expected: fields.len(),
                        });
                    }

                    for (e, t) in list.iter().zip_eq(fields) {
                        // a null field is only acceptable when that field is nullable; the
                        // resulting error is returned directly, not wrapped in `ElementType`
                        if let Datum::Null = e {
                            if t.nullable {
                                continue;
                            } else {
                                return Err(DatumTypeDifference::Null {
                                    expected: t.scalar_type.clone(),
                                });
                            }
                        }

                        difference_with_scalar_type(&e, &t.scalar_type)
                            .map_err(|e| element_type_difference("record", e))?;
                    }
                    Ok(())
                }
                (Datum::List(_), _) => mismatch(datum, scalar_type),
                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    for (_, v) in map.iter() {
                        // null values are accepted regardless of the value type
                        if let Datum::Null = v {
                            continue;
                        }

                        difference_with_scalar_type(&v, value_type)
                            .map_err(|e| element_type_difference("map", e))?;
                    }
                    Ok(())
                }
                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // JsonNull is only valid for Jsonb, which was handled in the branch above
                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
                    match inner {
                        // a range without an inner value has no bounds to check
                        None => Ok(()),
                        Some(inner) => {
                            // check whichever of the two bounds are present
                            if let Some(b) = inner.lower.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            if let Some(b) = inner.upper.bound {
                                difference_with_scalar_type(&b.datum(), element_type)
                                    .map_err(|e| element_type_difference("range", e))?;
                            }
                            Ok(())
                        }
                    }
                }
                (Datum::Range(_), _) => mismatch(datum, scalar_type),
                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
            }
        }
    }
    // A null datum is fine for a nullable column; everything else (including a null for a
    // non-nullable column) is decided by the scalar type.
    if column_type.nullable {
        if let Datum::Null = datum {
            return Ok(());
        }
    }
    difference_with_scalar_type(datum, &column_type.scalar_type)
}
787
788fn row_difference_with_column_types<'a>(
789    source: &'a MirRelationExpr,
790    datums: &Vec<Datum<'_>>,
791    column_types: &[ReprColumnType],
792) -> Result<(), TypeError<'a>> {
793    // correct length
794    if datums.len() != column_types.len() {
795        return Err(TypeError::BadConstantRowLen {
796            source,
797            got: datums.len(),
798            expected: column_types.to_vec(),
799        });
800    }
801
802    // correct types
803    let mut mismatches = Vec::new();
804    for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
805        if let Err(e) = datum_difference_with_column_type(d, ty) {
806            mismatches.push((i, e));
807        }
808    }
809    if !mismatches.is_empty() {
810        return Err(TypeError::BadConstantRow {
811            source,
812            mismatches,
813            expected: column_types.to_vec(),
814        });
815    }
816
817    Ok(())
818}
819/// Check that the visible type of each query has not been changed
820#[derive(Debug)]
821pub struct Typecheck {
822    /// The known types of the queries so far
823    ctx: SharedContext,
824    /// Whether or not this is the first run of the transform
825    disallow_new_globals: bool,
826    /// Whether or not to be strict about join equivalences having the same nullability
827    strict_join_equivalences: bool,
828    /// Whether or not to disallow dummy values
829    disallow_dummy: bool,
830    /// Recursion guard for checked recursion
831    recursion_guard: RecursionGuard,
832}
833
834impl CheckedRecursion for Typecheck {
835    fn recursion_guard(&self) -> &RecursionGuard {
836        &self.recursion_guard
837    }
838}
839
840impl Typecheck {
841    /// Creates a typechecking consistency checking pass using a given shared context
842    pub fn new(ctx: SharedContext) -> Self {
843        Self {
844            ctx,
845            disallow_new_globals: false,
846            strict_join_equivalences: false,
847            disallow_dummy: false,
848            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
849        }
850    }
851
852    /// New non-transient global IDs will be treated as an error
853    ///
854    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
855    pub fn disallow_new_globals(mut self) -> Self {
856        self.disallow_new_globals = true;
857        self
858    }
859
860    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
861    ///
862    /// Only turn this on before `JoinImplementation`
863    pub fn strict_join_equivalences(mut self) -> Self {
864        self.strict_join_equivalences = true;
865
866        self
867    }
868
869    /// Disallow dummy values
870    pub fn disallow_dummy(mut self) -> Self {
871        self.disallow_dummy = true;
872        self
873    }
874
875    /// Returns the type of a relation expression or a type error.
876    ///
877    /// This function is careful to check validity, not just find out the type.
878    ///
879    /// It should be linear in the size of the AST.
880    ///
881    /// ??? should we also compute keys and return a `ReprRelationType`?
882    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
883    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
884    ///   reliably keep it stable across transformations.
885    pub fn typecheck<'a>(
886        &self,
887        expr: &'a MirRelationExpr,
888        ctx: &Context,
889    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
890        use MirRelationExpr::*;
891
892        self.checked_recur(|tc| match expr {
893            Constant { typ, rows } => {
894                if let Ok(rows) = rows {
895                    for (row, _id) in rows {
896                        let datums = row.unpack();
897
898                        let col_types = typ
899                            .column_types
900                            .iter()
901                            .map(ReprColumnType::from)
902                            .collect_vec();
903                        row_difference_with_column_types(
904                            expr, &datums, &col_types,
905                        )?;
906
907                        if self.disallow_dummy
908                            && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
909                        {
910                            return Err(TypeError::DisallowedDummy {
911                                source: expr,
912                            });
913                        }
914                    }
915                }
916
917                Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
918            }
919            Get { typ, id, .. } => {
920                if let Id::Global(_global_id) = id {
921                    if !ctx.contains_key(id) {
922                        // TODO(mgree) pass QueryContext through to check these types
923                        return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
924                    }
925                }
926
927                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
928                    source: expr,
929                    id: id.clone(),
930                    typ: ReprRelationType::from(typ),
931                })?;
932
933                let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
934
935                // covariant: the ascribed type must be a subtype of the actual type in the context
936                let diffs = relation_subtype_difference(&column_types, ctx_typ)
937                    .into_iter()
938                    .flat_map(|diff| diff.ignore_nullability())
939                    .collect::<Vec<_>>();
940
941                if !diffs.is_empty() {
942                    return Err(TypeError::MismatchColumns {
943                        source: expr,
944                        got: column_types,
945                        expected: ctx_typ.clone(),
946                        diffs,
947                        message: "annotation did not match context type".to_string(),
948                    });
949                }
950
951                Ok(column_types)
952            }
953            Project { input, outputs } => {
954                let t_in = tc.typecheck(input, ctx)?;
955
956                for x in outputs {
957                    if *x >= t_in.len() {
958                        return Err(TypeError::BadProject {
959                            source: expr,
960                            got: outputs.clone(),
961                            input_type: t_in,
962                        });
963                    }
964                }
965
966                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
967            }
968            Map { input, scalars } => {
969                let mut t_in = tc.typecheck(input, ctx)?;
970
971                for scalar_expr in scalars.iter() {
972                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
973
974                    if self.disallow_dummy && scalar_expr.contains_dummy() {
975                        return Err(TypeError::DisallowedDummy {
976                            source: expr,
977                        });
978                    }
979                }
980
981                Ok(t_in)
982            }
983            FlatMap { input, func, exprs } => {
984                let mut t_in = tc.typecheck(input, ctx)?;
985
986                let mut t_exprs = Vec::with_capacity(exprs.len());
987                for scalar_expr in exprs {
988                    t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
989
990                    if self.disallow_dummy && scalar_expr.contains_dummy() {
991                        return Err(TypeError::DisallowedDummy {
992                            source: expr,
993                        });
994                    }
995                }
996                // TODO(mgree) check t_exprs agrees with `func`'s input type
997
998                let t_out = func
999                    .output_type()
1000                    .column_types
1001                    .iter()
1002                    .map(ReprColumnType::from)
1003                    .collect_vec();
1004
1005                // FlatMap extends the existing columns
1006                t_in.extend(t_out);
1007                Ok(t_in)
1008            }
1009            Filter { input, predicates } => {
1010                let mut t_in = tc.typecheck(input, ctx)?;
1011
1012                // Set as nonnull any columns where null values would cause
1013                // any predicate to evaluate to null.
1014                for column in non_nullable_columns(predicates) {
1015                    t_in[column].nullable = false;
1016                }
1017
1018                for scalar_expr in predicates {
1019                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1020
1021                    // filter condition must be boolean
1022                    // ignoring nullability: null is treated as false
1023                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1024                    if t.scalar_type != ReprScalarType::Bool {
1025                        let sub = t.scalar_type.clone();
1026
1027                        return Err(TypeError::MismatchColumn {
1028                            source: expr,
1029                            got: t,
1030                            expected: ReprColumnType {
1031                                scalar_type: ReprScalarType::Bool,
1032                                nullable: true,
1033                            },
1034                            diffs: vec![ReprColumnTypeDifference::NotSubtype {
1035                                sub,
1036                                sup: ReprScalarType::Bool,
1037                            }],
1038                            message: "expected boolean condition".to_string(),
1039                        });
1040                    }
1041
1042                    if self.disallow_dummy && scalar_expr.contains_dummy() {
1043                        return Err(TypeError::DisallowedDummy {
1044                            source: expr,
1045                        });
1046                    }
1047                }
1048
1049                Ok(t_in)
1050            }
1051            Join {
1052                inputs,
1053                equivalences,
1054                implementation,
1055            } => {
1056                let mut t_in_global = Vec::new();
1057                let mut t_in_local = vec![Vec::new(); inputs.len()];
1058
1059                for (i, input) in inputs.iter().enumerate() {
1060                    let input_t = tc.typecheck(input, ctx)?;
1061                    t_in_global.extend(input_t.clone());
1062                    t_in_local[i] = input_t;
1063                }
1064
1065                for eq_class in equivalences {
1066                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1067
1068                    let mut all_nullable = true;
1069
1070                    for scalar_expr in eq_class {
1071                        // Note: the equivalences have global column references
1072                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1073
1074                        if !t_expr.nullable {
1075                            all_nullable = false;
1076                        }
1077
1078                        if let Some(t_first) = t_exprs.get(0) {
1079                            let diffs = scalar_subtype_difference(
1080                                &t_expr.scalar_type,
1081                                &t_first.scalar_type,
1082                            );
1083                            if !diffs.is_empty() {
1084                                return Err(TypeError::MismatchColumn {
1085                                    source: expr,
1086                                    got: t_expr,
1087                                    expected: t_first.clone(),
1088                                    diffs,
1089                                    message: "equivalence class members \
1090                                        have different scalar types"
1091                                        .to_string(),
1092                                });
1093                            }
1094
1095                            // equivalences may or may not match on nullability
1096                            // before JoinImplementation runs, nullability should match.
1097                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
1098                            if self.strict_join_equivalences {
1099                                if t_expr.nullable != t_first.nullable {
1100                                    let sub = t_expr.clone();
1101                                    let sup = t_first.clone();
1102
1103                                    let err = TypeError::MismatchColumn {
1104                                        source: expr,
1105                                        got: t_expr.clone(),
1106                                        expected: t_first.clone(),
1107                                        diffs: vec![
1108                                            ReprColumnTypeDifference::Nullability { sub, sup },
1109                                        ],
1110                                        message: "equivalence class members have \
1111                                            different nullability (and join \
1112                                            equivalence checking is strict)"
1113                                            .to_string(),
1114                                    };
1115
1116                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1117                                    ::tracing::debug!("{err}");
1118                                }
1119                            }
1120                        }
1121
1122                        if self.disallow_dummy && scalar_expr.contains_dummy() {
1123                            return Err(TypeError::DisallowedDummy {
1124                                source: expr,
1125                            });
1126                        }
1127
1128                        t_exprs.push(t_expr);
1129                    }
1130
1131                    if self.strict_join_equivalences && all_nullable {
1132                        let err = TypeError::BadJoinEquivalence {
1133                            source: expr,
1134                            got: t_exprs,
1135                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1136                        };
1137
1138                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1139                        ::tracing::debug!("{err}");
1140                    }
1141                }
1142
1143                // check that the join implementation is consistent
1144                match implementation {
1145                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
1146                        if let Some(key) = first_key {
1147                            for k in key {
1148                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1149                            }
1150                        }
1151
1152                        for (idx, key, _) in others {
1153                            for k in key {
1154                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1155                            }
1156                        }
1157                    }
1158                    JoinImplementation::DeltaQuery(plans) => {
1159                        for plan in plans {
1160                            for (idx, key, _) in plan {
1161                                for k in key {
1162                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1163                                }
1164                            }
1165                        }
1166                    }
1167                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1168                        let typ: Vec<ReprColumnType> = key
1169                            .iter()
1170                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1171                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1172
1173                        for row in consts {
1174                            let datums = row.unpack();
1175
1176                            row_difference_with_column_types(expr, &datums, &typ)?;
1177                        }
1178                    }
1179                    JoinImplementation::Unimplemented => (),
1180                }
1181
1182                Ok(t_in_global)
1183            }
1184            Reduce {
1185                input,
1186                group_key,
1187                aggregates,
1188                monotonic: _,
1189                expected_group_size: _,
1190            } => {
1191                let t_in = tc.typecheck(input, ctx)?;
1192
1193                let mut t_out = group_key
1194                    .iter()
1195                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1196                    .collect::<Result<Vec<_>, _>>()?;
1197
1198                    if self.disallow_dummy
1199                        && group_key
1200                            .iter()
1201                            .any(|scalar_expr| scalar_expr.contains_dummy())
1202                    {
1203                        return Err(TypeError::DisallowedDummy {
1204                            source: expr,
1205                        });
1206                    }
1207
1208                for agg in aggregates {
1209                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1210                }
1211
1212                Ok(t_out)
1213            }
1214            TopK {
1215                input,
1216                group_key,
1217                order_key,
1218                limit: _,
1219                offset: _,
1220                monotonic: _,
1221                expected_group_size: _,
1222            } => {
1223                let t_in = tc.typecheck(input, ctx)?;
1224
1225                for &k in group_key {
1226                    if k >= t_in.len() {
1227                        return Err(TypeError::BadTopKGroupKey {
1228                            source: expr,
1229                            k,
1230                            input_type: t_in,
1231                        });
1232                    }
1233                }
1234
1235                for order in order_key {
1236                    if order.column >= t_in.len() {
1237                        return Err(TypeError::BadTopKOrdering {
1238                            source: expr,
1239                            order: order.clone(),
1240                            input_type: t_in,
1241                        });
1242                    }
1243                }
1244
1245                Ok(t_in)
1246            }
1247            Negate { input } => tc.typecheck(input, ctx),
1248            Threshold { input } => tc.typecheck(input, ctx),
1249            Union { base, inputs } => {
1250                let mut t_base = tc.typecheck(base, ctx)?;
1251
1252                for input in inputs {
1253                    let t_input = tc.typecheck(input, ctx)?;
1254
1255                    let len_sub = t_base.len();
1256                    let len_sup = t_input.len();
1257                    if len_sub != len_sup {
1258                        return Err(TypeError::MismatchColumns {
1259                            source: expr,
1260                            got: t_base.clone(),
1261                            expected: t_input,
1262                            diffs: vec![ReprRelationTypeDifference::Length {
1263                                len_sub,
1264                                len_sup,
1265                            }],
1266                            message: "Union branches have different numbers of columns".to_string(),
1267                        });
1268                    }
1269
1270                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1271                        let diffs = column_union(base_col, &input_col);
1272                        if !diffs.is_empty() {
1273                            return Err(TypeError::MismatchColumn {
1274                                    source: expr,
1275                                    got: input_col,
1276                                    expected: base_col.clone(),
1277                                    diffs,
1278                                    message:
1279                                        "couldn't compute union of column types in Union"
1280                                    .to_string(),
1281                            });
1282                        }
1283
1284                    }
1285                }
1286
1287                Ok(t_base)
1288            }
1289            Let { id, value, body } => {
1290                let t_value = tc.typecheck(value, ctx)?;
1291
1292                let binding = Id::Local(*id);
1293                if ctx.contains_key(&binding) {
1294                    return Err(TypeError::Shadowing {
1295                        source: expr,
1296                        id: binding,
1297                    });
1298                }
1299
1300                let mut body_ctx = ctx.clone();
1301                body_ctx.insert(Id::Local(*id), t_value);
1302
1303                tc.typecheck(body, &body_ctx)
1304            }
1305            LetRec { ids, values, body, limits: _ } => {
1306                if ids.len() != values.len() {
1307                    return Err(TypeError::BadLetRecBindings { source: expr });
1308                }
1309
1310                // temporary hack: steal info from the Gets inside to learn the expected types
1311                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1312                let mut ctx = ctx.clone();
1313                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1314                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1315                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1316                }
1317
1318                for (id, value) in ids.iter().zip_eq(values.iter()) {
1319                    let typ = tc.typecheck(value, &ctx)?;
1320
1321                    let id = Id::Local(id.clone());
1322                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1323                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1324                            // we expect an EXACT match, but don't care about nullability
1325                            let diffs = column_union(base_col, &input_col);
1326                            if !diffs.is_empty() {
1327                                 return Err(TypeError::MismatchColumn {
1328                                        source: expr,
1329                                        got: input_col,
1330                                        expected: base_col.clone(),
1331                                        diffs,
1332                                        message:
1333                                            "couldn't compute union of column types in LetRec"
1334                                        .to_string(),
1335                                    })
1336                            }
1337                        }
1338                    } else {
1339                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1340                        ctx.insert(id, typ);
1341                    }
1342                }
1343
1344                tc.typecheck(body, &ctx)
1345            }
1346            ArrangeBy { input, keys } => {
1347                let t_in = tc.typecheck(input, ctx)?;
1348
1349                for key in keys {
1350                    for k in key {
1351                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1352                    }
1353                }
1354
1355                Ok(t_in)
1356            }
1357        })
1358    }
1359
1360    /// Traverses a term to collect the types of given ids.
1361    ///
1362    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1363    fn collect_recursive_variable_types<'a>(
1364        &self,
1365        expr: &'a MirRelationExpr,
1366        ids: &[LocalId],
1367        ctx: &mut Context,
1368    ) -> Result<(), TypeError<'a>> {
1369        use MirRelationExpr::*;
1370
1371        self.checked_recur(|tc| {
1372            match expr {
1373                Get {
1374                    id: Id::Local(id),
1375                    typ,
1376                    ..
1377                } => {
1378                    if !ids.contains(id) {
1379                        return Ok(());
1380                    }
1381
1382                    let id = Id::Local(id.clone());
1383                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1384                        let typ = typ
1385                            .column_types
1386                            .iter()
1387                            .map(ReprColumnType::from)
1388                            .collect_vec();
1389
1390                        if ctx_typ.len() != typ.len() {
1391                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1392
1393                            return Err(TypeError::MismatchColumns {
1394                                source: expr,
1395                                got: typ,
1396                                expected: ctx_typ.clone(),
1397                                diffs,
1398                                message: "environment and type annotation did not match"
1399                                    .to_string(),
1400                            });
1401                        }
1402
1403                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1404                            let diffs = column_union(base_col, &input_col);
1405                            if !diffs.is_empty() {
1406                                return Err(TypeError::MismatchColumn {
1407                                    source: expr,
1408                                    got: input_col,
1409                                    expected: base_col.clone(),
1410                                    diffs,
1411                                    message:
1412                                        "couldn't compute union of column types in Get and context"
1413                                            .to_string(),
1414                                });
1415                            }
1416                        }
1417                    } else {
1418                        ctx.insert(
1419                            id,
1420                            typ.column_types
1421                                .iter()
1422                                .map(ReprColumnType::from)
1423                                .collect_vec(),
1424                        );
1425                    }
1426                }
1427                Get {
1428                    id: Id::Global(..), ..
1429                }
1430                | Constant { .. } => (),
1431                Let { id, value, body } => {
1432                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1433
1434                    // we've shadowed the id
1435                    if ids.contains(id) {
1436                        return Err(TypeError::Shadowing {
1437                            source: expr,
1438                            id: Id::Local(*id),
1439                        });
1440                    }
1441
1442                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1443                }
1444                LetRec {
1445                    ids: inner_ids,
1446                    values,
1447                    body,
1448                    limits: _,
1449                } => {
1450                    for inner_id in inner_ids {
1451                        if ids.contains(inner_id) {
1452                            return Err(TypeError::Shadowing {
1453                                source: expr,
1454                                id: Id::Local(*inner_id),
1455                            });
1456                        }
1457                    }
1458
1459                    for value in values {
1460                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1461                    }
1462
1463                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1464                }
1465                Project { input, .. }
1466                | Map { input, .. }
1467                | FlatMap { input, .. }
1468                | Filter { input, .. }
1469                | Reduce { input, .. }
1470                | TopK { input, .. }
1471                | Negate { input }
1472                | Threshold { input }
1473                | ArrangeBy { input, .. } => {
1474                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1475                }
1476                Join { inputs, .. } => {
1477                    for input in inputs {
1478                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1479                    }
1480                }
1481                Union { base, inputs } => {
1482                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1483
1484                    for input in inputs {
1485                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1486                    }
1487                }
1488            }
1489
1490            Ok(())
1491        })
1492    }
1493
1494    fn typecheck_scalar<'a>(
1495        &self,
1496        expr: &'a MirScalarExpr,
1497        source: &'a MirRelationExpr,
1498        column_types: &[ReprColumnType],
1499    ) -> Result<ReprColumnType, TypeError<'a>> {
1500        use MirScalarExpr::*;
1501
1502        self.checked_recur(|tc| match expr {
1503            Column(i, _) => match column_types.get(*i) {
1504                Some(ty) => Ok(ty.clone()),
1505                None => Err(TypeError::NoSuchColumn {
1506                    source,
1507                    expr,
1508                    col: *i,
1509                }),
1510            },
1511            Literal(row, typ) => {
1512                let typ = ReprColumnType::from(typ);
1513                if let Ok(row) = row {
1514                    let datums = row.unpack();
1515
1516                    row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1517                }
1518
1519                Ok(typ)
1520            }
1521            CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1522            CallUnary { expr, func } => {
1523                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1524                let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1525                Ok(ReprColumnType::from(&typ_out))
1526            }
1527            CallBinary { expr1, expr2, func } => {
1528                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1529                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1530                let typ_out = func.output_type(&[
1531                    SqlColumnType::from_repr(&typ_in1),
1532                    SqlColumnType::from_repr(&typ_in2),
1533                ]);
1534                Ok(ReprColumnType::from(&typ_out))
1535            }
1536            CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1537                &func.output_type(
1538                    exprs
1539                        .iter()
1540                        .map(|e| {
1541                            tc.typecheck_scalar(e, source, column_types)
1542                                .map(|typ| SqlColumnType::from_repr(&typ))
1543                        })
1544                        .collect::<Result<Vec<_>, TypeError>>()?,
1545                ),
1546            )),
1547            If { cond, then, els } => {
1548                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1549
1550                // condition must be boolean
1551                // ignoring nullability: null is treated as false
1552                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1553                if cond_type.scalar_type != ReprScalarType::Bool {
1554                    let sub = cond_type.scalar_type.clone();
1555
1556                    return Err(TypeError::MismatchColumn {
1557                        source,
1558                        got: cond_type,
1559                        expected: ReprColumnType {
1560                            scalar_type: ReprScalarType::Bool,
1561                            nullable: true,
1562                        },
1563                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1564                            sub,
1565                            sup: ReprScalarType::Bool,
1566                        }],
1567                        message: "expected boolean condition".to_string(),
1568                    });
1569                }
1570
1571                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1572                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1573
1574                let diffs = column_union(&mut then_type, &else_type);
1575                if !diffs.is_empty() {
1576                    return Err(TypeError::MismatchColumn {
1577                        source,
1578                        got: then_type,
1579                        expected: else_type,
1580                        diffs,
1581                        message: "couldn't compute union of column types for If".to_string(),
1582                    });
1583                }
1584
1585                Ok(then_type)
1586            }
1587        })
1588    }
1589
1590    /// Typecheck an `AggregateExpr`
1591    pub fn typecheck_aggregate<'a>(
1592        &self,
1593        expr: &'a AggregateExpr,
1594        source: &'a MirRelationExpr,
1595        column_types: &[ReprColumnType],
1596    ) -> Result<ReprColumnType, TypeError<'a>> {
1597        self.checked_recur(|tc| {
1598            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1599
1600            // TODO check that t_in is actually acceptable for `func`
1601
1602            Ok(ReprColumnType::from(
1603                &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1604            ))
1605        })
1606    }
1607}
1608
/// Reports a type error through a severity-dependent channel.
///
/// `type_error!(severity, ...)` forwards the format arguments to `soft_panic_or_log!` when
/// `severity` is `true` (intended to fail in CI and log an error, visible in Sentry, in
/// production), and to `tracing::debug!` otherwise.
macro_rules! type_error {
    ($is_severe:expr, $($message:tt)+) => {{
        if $is_severe {
            soft_panic_or_log!($($message)+);
        } else {
            ::tracing::debug!($($message)+);
        }
    }};
}
1621
1622impl crate::Transform for Typecheck {
1623    fn name(&self) -> &'static str {
1624        "Typecheck"
1625    }
1626
1627    fn actually_perform_transform(
1628        &self,
1629        relation: &mut MirRelationExpr,
1630        transform_ctx: &mut crate::TransformCtx,
1631    ) -> Result<(), crate::TransformError> {
1632        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1633
1634        let expected = transform_ctx
1635            .global_id
1636            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1637
1638        if let Some(id) = transform_ctx.global_id {
1639            if self.disallow_new_globals
1640                && expected.is_none()
1641                && transform_ctx.global_id.is_some()
1642                && !id.is_transient()
1643            {
1644                type_error!(
1645                    false, // not severe
1646                    "type warning: new non-transient global id {id}\n{}",
1647                    relation.pretty()
1648                );
1649            }
1650        }
1651
1652        let got = self.typecheck(relation, &typecheck_ctx);
1653
1654        let humanizer = mz_repr::explain::DummyHumanizer;
1655
1656        match (got, expected) {
1657            (Ok(got), Some(expected)) => {
1658                let id = transform_ctx.global_id.unwrap();
1659
1660                // contravariant: global types can be updated
1661                let diffs = relation_subtype_difference(expected, &got);
1662                if !diffs.is_empty() {
1663                    // SEVERE only if got and expected have true differences, not just nullability
1664                    let severity = diffs
1665                        .iter()
1666                        .any(|diff| diff.clone().ignore_nullability().is_some());
1667
1668                    let err = TypeError::MismatchColumns {
1669                        source: relation,
1670                        got,
1671                        expected: expected.clone(),
1672                        diffs,
1673                        message: format!(
1674                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1675                        ),
1676                    };
1677
1678                    type_error!(severity, "type error in known global id {id}:\n{err}");
1679                }
1680            }
1681            (Ok(got), None) => {
1682                if let Some(id) = transform_ctx.global_id {
1683                    typecheck_ctx.insert(Id::Global(id), got);
1684                }
1685            }
1686            (Err(err), _) => {
1687                let (expected, binding) = match expected {
1688                    Some(expected) => {
1689                        let id = transform_ctx.global_id.unwrap();
1690                        (
1691                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1692                            format!("known global id {id}"),
1693                        )
1694                    }
1695                    None => ("".to_string(), "transient query".to_string()),
1696                };
1697
1698                type_error!(
1699                    true, // SEVERE: the transformed code is inconsistent
1700                    "type error in {binding}:\n{err}\n{expected}{}",
1701                    relation.pretty()
1702                );
1703            }
1704        }
1705
1706        Ok(())
1707    }
1708}
1709
1710/// Prints a type prettily with a given `ExprHumanizer`
1711pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1712where
1713    H: ExprHumanizer,
1714{
1715    let mut s = String::with_capacity(2 + 3 * cols.len());
1716
1717    s.push('(');
1718
1719    let mut it = cols.iter().peekable();
1720    while let Some(col) = it.next() {
1721        s.push_str(&humanizer.humanize_column_type_repr(col, false));
1722
1723        if it.peek().is_some() {
1724            s.push_str(", ");
1725        }
1726    }
1727
1728    s.push(')');
1729
1730    s
1731}
1732
1733impl ReprRelationTypeDifference {
1734    /// Pretty prints a type difference
1735    ///
1736    /// Always indents two spaces
1737    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1738    where
1739        H: ExprHumanizer,
1740    {
1741        use ReprRelationTypeDifference::*;
1742        match self {
1743            Length { len_sub, len_sup } => {
1744                writeln!(
1745                    f,
1746                    "  number of columns do not match ({len_sub} != {len_sup})"
1747                )
1748            }
1749            Column { col, diff } => {
1750                writeln!(f, "  column {col} differs:")?;
1751                diff.humanize(4, h, f)
1752            }
1753        }
1754    }
1755}
1756
1757impl ReprColumnTypeDifference {
1758    /// Pretty prints a type difference at a given indentation level
1759    pub fn humanize<H>(
1760        &self,
1761        indent: usize,
1762        h: &H,
1763        f: &mut std::fmt::Formatter<'_>,
1764    ) -> std::fmt::Result
1765    where
1766        H: ExprHumanizer,
1767    {
1768        use ReprColumnTypeDifference::*;
1769
1770        // indent
1771        write!(f, "{:indent$}", "")?;
1772
1773        match self {
1774            NotSubtype { sub, sup } => {
1775                let sub = h.humanize_scalar_type_repr(sub, false);
1776                let sup = h.humanize_scalar_type_repr(sup, false);
1777
1778                writeln!(f, "{sub} is a not a subtype of {sup}")
1779            }
1780            Nullability { sub, sup } => {
1781                let sub = h.humanize_column_type_repr(sub, false);
1782                let sup = h.humanize_column_type_repr(sup, false);
1783
1784                writeln!(f, "{sub} is nullable but {sup} is not")
1785            }
1786            ElementType { ctor, element_type } => {
1787                writeln!(f, "{ctor} element types differ:")?;
1788
1789                element_type.humanize(indent + 2, h, f)
1790            }
1791            RecordMissingFields { missing } => {
1792                write!(f, "missing column fields:")?;
1793                for col in missing {
1794                    write!(f, " {col}")?;
1795                }
1796                f.write_char('\n')
1797            }
1798            RecordFields { fields } => {
1799                writeln!(f, "{} record fields differ:", fields.len())?;
1800
1801                for (i, diff) in fields.iter().enumerate() {
1802                    writeln!(f, "{:indent$}  field {i}:", "")?;
1803                    diff.humanize(indent + 4, h, f)?;
1804                }
1805                Ok(())
1806            }
1807        }
1808    }
1809}
1810
1811impl DatumTypeDifference {
1812    /// Pretty prints a type difference at a given indentation level
1813    pub fn humanize<H>(
1814        &self,
1815        indent: usize,
1816        h: &H,
1817        f: &mut std::fmt::Formatter<'_>,
1818    ) -> std::fmt::Result
1819    where
1820        H: ExprHumanizer,
1821    {
1822        // indent
1823        write!(f, "{:indent$}", "")?;
1824
1825        match self {
1826            DatumTypeDifference::Null { expected } => {
1827                let expected = h.humanize_scalar_type_repr(expected, false);
1828                writeln!(
1829                    f,
1830                    "unexpected null, expected representation type {expected}"
1831                )?
1832            }
1833            DatumTypeDifference::Mismatch {
1834                got_debug,
1835                expected,
1836            } => {
1837                let expected = h.humanize_scalar_type_repr(expected, false);
1838                // NB `got_debug` will be redacted as appropriate
1839                writeln!(
1840                    f,
1841                    "got datum {got_debug}, expected representation type {expected}"
1842                )?;
1843            }
1844            DatumTypeDifference::MismatchDimensions {
1845                ctor,
1846                got,
1847                expected,
1848            } => {
1849                writeln!(
1850                    f,
1851                    "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1852                )?;
1853            }
1854            DatumTypeDifference::ElementType { ctor, element_type } => {
1855                writeln!(f, "{ctor} element types differ:")?;
1856                element_type.humanize(indent + 4, h, f)?;
1857            }
1858        }
1859
1860        Ok(())
1861    }
1862}
1863
1864/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
1865#[allow(missing_debug_implementations)]
1866pub struct TypeErrorHumanizer<'a, 'b, H>
1867where
1868    H: ExprHumanizer,
1869{
1870    err: &'a TypeError<'a>,
1871    humanizer: &'b H,
1872}
1873
1874impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1875where
1876    H: ExprHumanizer,
1877{
1878    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1879    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1880        Self { err, humanizer }
1881    }
1882}
1883
1884impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1885where
1886    H: ExprHumanizer,
1887{
1888    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1889        self.err.humanize(self.humanizer, f)
1890    }
1891}
1892
1893impl<'a> std::fmt::Display for TypeError<'a> {
1894    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1895        TypeErrorHumanizer {
1896            err: self,
1897            humanizer: &DummyHumanizer,
1898        }
1899        .fmt(f)
1900    }
1901}
1902
1903impl<'a> TypeError<'a> {
1904    /// The source of the type error
1905    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1906        use TypeError::*;
1907        match self {
1908            Unbound { source, .. }
1909            | NoSuchColumn { source, .. }
1910            | MismatchColumn { source, .. }
1911            | MismatchColumns { source, .. }
1912            | BadConstantRowLen { source, .. }
1913            | BadConstantRow { source, .. }
1914            | BadProject { source, .. }
1915            | BadJoinEquivalence { source, .. }
1916            | BadTopKGroupKey { source, .. }
1917            | BadTopKOrdering { source, .. }
1918            | BadLetRecBindings { source }
1919            | Shadowing { source, .. }
1920            | DisallowedDummy { source, .. } => Some(source),
1921            Recursion { .. } => None,
1922        }
1923    }
1924
1925    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1926    where
1927        H: ExprHumanizer,
1928    {
1929        if let Some(source) = self.source() {
1930            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1931        }
1932
1933        use TypeError::*;
1934        match self {
1935            Unbound { source: _, id, typ } => {
1936                let typ = columns_pretty(&typ.column_types, humanizer);
1937                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1938            }
1939            NoSuchColumn {
1940                source: _,
1941                expr,
1942                col,
1943            } => writeln!(f, "{expr} references non-existent column {col}")?,
1944            MismatchColumn {
1945                source: _,
1946                got,
1947                expected,
1948                diffs,
1949                message,
1950            } => {
1951                let got = humanizer.humanize_column_type_repr(got, false);
1952                let expected = humanizer.humanize_column_type_repr(expected, false);
1953                writeln!(
1954                    f,
1955                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1956                )?;
1957
1958                for diff in diffs {
1959                    diff.humanize(2, humanizer, f)?;
1960                }
1961            }
1962            MismatchColumns {
1963                source: _,
1964                got,
1965                expected,
1966                diffs,
1967                message,
1968            } => {
1969                let got = columns_pretty(got, humanizer);
1970                let expected = columns_pretty(expected, humanizer);
1971
1972                writeln!(
1973                    f,
1974                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1975                )?;
1976
1977                for diff in diffs {
1978                    diff.humanize(humanizer, f)?;
1979                }
1980            }
1981            BadConstantRowLen {
1982                source: _,
1983                got,
1984                expected,
1985            } => {
1986                let expected = columns_pretty(expected, humanizer);
1987                writeln!(
1988                    f,
1989                    "bad constant row\n      row has length {got}\nexpected row of type {expected}"
1990                )?
1991            }
1992            BadConstantRow {
1993                source: _,
1994                mismatches,
1995                expected,
1996            } => {
1997                let expected = columns_pretty(expected, humanizer);
1998
1999                let num_mismatches = mismatches.len();
2000                let plural = if num_mismatches == 1 { "" } else { "es" };
2001                writeln!(
2002                    f,
2003                    "bad constant row\n      got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
2004                )?;
2005
2006                if num_mismatches > 0 {
2007                    writeln!(f, "")?;
2008                    for (col, diff) in mismatches.iter() {
2009                        writeln!(f, "      column #{col}:")?;
2010                        diff.humanize(8, humanizer, f)?;
2011                    }
2012                }
2013            }
2014            BadProject {
2015                source: _,
2016                got,
2017                input_type,
2018            } => {
2019                let input_type = columns_pretty(input_type, humanizer);
2020
2021                writeln!(
2022                    f,
2023                    "projection of non-existant columns {got:?} from type {input_type}"
2024                )?
2025            }
2026            BadJoinEquivalence {
2027                source: _,
2028                got,
2029                message,
2030            } => {
2031                let got = columns_pretty(got, humanizer);
2032
2033                writeln!(f, "bad join equivalence {got}: {message}")?
2034            }
2035            BadTopKGroupKey {
2036                source: _,
2037                k,
2038                input_type,
2039            } => {
2040                let input_type = columns_pretty(input_type, humanizer);
2041
2042                writeln!(
2043                    f,
2044                    "TopK group key component references invalid column {k} in columns: {input_type}"
2045                )?
2046            }
2047            BadTopKOrdering {
2048                source: _,
2049                order,
2050                input_type,
2051            } => {
2052                let col = order.column;
2053                let num_cols = input_type.len();
2054                let are = if num_cols == 1 { "is" } else { "are" };
2055                let s = if num_cols == 1 { "" } else { "s" };
2056                let input_type = columns_pretty(input_type, humanizer);
2057
2058                // TODO(cloud#8196)
2059                let mode = HumanizedExplain::new(false);
2060                let order = mode.expr(order, None);
2061
2062                writeln!(
2063                    f,
2064                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2065                )?
2066            }
2067            BadLetRecBindings { source: _ } => {
2068                writeln!(f, "LetRec ids and definitions don't line up")?
2069            }
2070            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2071            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2072            Recursion { error } => writeln!(f, "{error}")?,
2073        }
2074
2075        Ok(())
2076    }
2077}
2078
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    // Smoke test: an `Int16` datum is accepted for a nullable `Int16` column and
    // rejected for a non-nullable `Int32` column.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    // Datums generated *for* a given column type: `datum_difference_with_column_type`
    // must agree with `Datum::is_instance_of`.
    //
    // Cases whose datum contains a dummy are rejected in the test body, which is
    // presumably why `max_global_rejects` is raised here.
    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        // We run many cases because the data are _random_ (the datum is generated
        // independently of the column type), and we want to be sure that we have
        // covered sufficient cases.
        //
        // NOTE(review): an earlier version of this comment mentioned dropping dummy data and
        // allowing more global rejects; this test does neither. It asserts that
        // `arb_datum(false)` never yields a dummy, and the config only sets the case count.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    // Regression test (the name suggests GitHub issue 10039 -- TODO confirm): a list
    // containing a single null must not be accepted as a non-nullable record with one
    // non-nullable `UInt32` field, by either `is_instance_of` or
    // `datum_difference_with_column_type`.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}