mz_transform/
reprtypecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27    ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28    SqlColumnType,
29};
30
31/// Typechecking contexts as shared by various typechecking passes.
32///
33/// We use a `RefCell` to ensure that contexts are shared by multiple typechecker passes.
34/// Shared contexts help catch consistency issues.
35pub type SharedContext = Arc<Mutex<Context>>;
36
37/// Generates an empty context
38pub fn empty_context() -> SharedContext {
39    Arc::new(Mutex::new(BTreeMap::new()))
40}
41
42/// The possible forms of inconsistency/errors discovered during typechecking.
43///
44/// Every variant has a `source` field identifying the MIR term that is home
45/// to the error (though not necessarily the root cause of the error).
46#[derive(Clone, Debug)]
47pub enum TypeError<'a> {
48    /// Unbound identifiers (local or global)
49    Unbound {
50        /// Expression with the bug
51        source: &'a MirRelationExpr,
52        /// The (unbound) identifier referenced
53        id: Id,
54        /// The type `id` was expected to have
55        typ: ReprRelationType,
56    },
57    /// Dereference of a non-existent column
58    NoSuchColumn {
59        /// Expression with the bug
60        source: &'a MirRelationExpr,
61        /// Scalar expression that references an invalid column
62        expr: &'a MirScalarExpr,
63        /// The invalid column referenced
64        col: usize,
65    },
66    /// A single column type does not match
67    MismatchColumn {
68        /// Expression with the bug
69        source: &'a MirRelationExpr,
70        /// The column type we found (`sub` type)
71        got: ReprColumnType,
72        /// The column type we expected (`sup` type)
73        expected: ReprColumnType,
74        /// The difference between these types
75        diffs: Vec<ReprColumnTypeDifference>,
76        /// An explanatory message
77        message: String,
78    },
79    /// Relation column types do not match
80    MismatchColumns {
81        /// Expression with the bug
82        source: &'a MirRelationExpr,
83        /// The column types we found (`sub` type)
84        got: Vec<ReprColumnType>,
85        /// The column types we expected (`sup` type)
86        expected: Vec<ReprColumnType>,
87        /// The difference between these types
88        diffs: Vec<ReprRelationTypeDifference>,
89        /// An explanatory message
90        message: String,
91    },
92    /// A constant row does not have the correct type
93    BadConstantRowLen {
94        /// Expression with the bug
95        source: &'a MirRelationExpr,
96        /// A constant row
97        got: usize,
98        /// The expected type (which that row does not have)
99        expected: Vec<ReprColumnType>,
100    },
101    /// A constant row does not have the correct type
102    BadConstantRow {
103        /// Expression with the bug
104        source: &'a MirRelationExpr,
105        /// The columns that mismatched, with a debug rendering of the datum we got and the expected type
106        mismatches: Vec<(usize, DatumTypeDifference)>,
107        /// The expected type (which that row does not have)
108        expected: Vec<ReprColumnType>,
109        // TODO(mgree) with a good way to get the type of a Datum, we could give a detailed diff here
110        // this is difficult for empty structures where we may not know the element type
111    },
112    /// Projection of a non-existent column
113    BadProject {
114        /// Expression with the bug
115        source: &'a MirRelationExpr,
116        /// The column projected
117        got: Vec<usize>,
118        /// The input columns (which don't have that column)
119        input_type: Vec<ReprColumnType>,
120    },
121    /// An equivalence class in a join was malformed
122    BadJoinEquivalence {
123        /// Expression with the bug
124        source: &'a MirRelationExpr,
125        /// The join equivalences
126        got: Vec<ReprColumnType>,
127        /// The problem with the join equivalences
128        message: String,
129    },
130    /// TopK grouping by non-existent column
131    BadTopKGroupKey {
132        /// Expression with the bug
133        source: &'a MirRelationExpr,
134        /// The bad column reference in the group key
135        k: usize,
136        /// The input columns (which don't have that column)
137        input_type: Vec<ReprColumnType>,
138    },
139    /// TopK ordering by non-existent column
140    BadTopKOrdering {
141        /// Expression with the bug
142        source: &'a MirRelationExpr,
143        /// The ordering used
144        order: ColumnOrder,
145        /// The input columns (which don't work for that ordering)
146        input_type: Vec<ReprColumnType>,
147    },
148    /// LetRec bindings are malformed
149    BadLetRecBindings {
150        /// Expression with the bug
151        source: &'a MirRelationExpr,
152    },
153    /// Local identifiers are shadowed
154    Shadowing {
155        /// Expression with the bug
156        source: &'a MirRelationExpr,
157        /// The id that was shadowed
158        id: Id,
159    },
160    /// Recursion depth exceeded
161    Recursion {
162        /// The error that aborted recursion
163        error: RecursionLimitError,
164    },
165    /// A dummy value was found
166    DisallowedDummy {
167        /// The expression with the dummy value
168        source: &'a MirRelationExpr,
169    },
170}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173    fn from(error: RecursionLimitError) -> Self {
174        TypeError::Recursion { error }
175    }
176}
177
178type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
180/// Characterizes differences between relation types
181///
182/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
183#[derive(Clone, Debug, Hash)]
184pub enum ReprRelationTypeDifference {
185    /// `sub` and `sup` don't have the same number of columns
186    Length {
187        /// Length of `sub`
188        len_sub: usize,
189        /// Length of `sup`
190        len_sup: usize,
191    },
192    /// `sub` and `sup` differ at the indicated column
193    Column {
194        /// The column at which `sub` and `sup` differ
195        col: usize,
196        /// The difference between `sub` and `sup`
197        diff: ReprColumnTypeDifference,
198    },
199}
200
201/// Characterizes differences between individual column types
202///
203/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
204/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types
205#[derive(Clone, Debug, Hash)]
206pub enum ReprColumnTypeDifference {
207    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
208    NotSubtype {
209        /// Would-be subtype
210        sub: ReprScalarType,
211        /// Would-be supertype
212        sup: ReprScalarType,
213    },
214    /// `sub` was nullable but `sup` was not
215    Nullability {
216        /// Would-be subtype
217        sub: ReprColumnType,
218        /// Would-be supertype
219        sup: ReprColumnType,
220    },
221    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`s
222    ElementType {
223        /// The type constructor (list, array, etc.)
224        ctor: String,
225        /// The difference in the element type
226        element_type: Box<ReprColumnTypeDifference>,
227    },
228    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
229    RecordMissingFields {
230        /// The missing fields
231        missing: Vec<ColumnName>,
232    },
233    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
234    RecordFields {
235        /// The differences, by field
236        fields: Vec<ReprColumnTypeDifference>,
237    },
238}
239
240impl ReprRelationTypeDifference {
241    /// Returns the same type difference, but ignoring nullability
242    ///
243    /// Returns `None` when _all_ of the differences are due to nullability
244    pub fn ignore_nullability(self) -> Option<Self> {
245        use ReprRelationTypeDifference::*;
246
247        match self {
248            Length { .. } => Some(self),
249            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250        }
251    }
252}
253
254impl ReprColumnTypeDifference {
255    /// Returns the same type difference, but ignoring nullability
256    ///
257    /// Returns `None` when _all_ of the differences are due to nullability
258    pub fn ignore_nullability(self) -> Option<Self> {
259        use ReprColumnTypeDifference::*;
260
261        match self {
262            Nullability { .. } => None,
263            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264            ElementType { ctor, element_type } => {
265                element_type
266                    .ignore_nullability()
267                    .map(|element_type| ElementType {
268                        ctor,
269                        element_type: Box::new(element_type),
270                    })
271            }
272            RecordFields { fields } => {
273                let fields = fields
274                    .into_iter()
275                    .flat_map(|diff| diff.ignore_nullability())
276                    .collect::<Vec<_>>();
277
278                if fields.is_empty() {
279                    None
280                } else {
281                    Some(RecordFields { fields })
282                }
283            }
284        }
285    }
286}
287
288/// Returns a list of differences that make `sub` not a subtype of `sup`
289///
290/// This function returns an empty list when `sub` is a subtype of `sup`
291pub fn relation_subtype_difference(
292    sub: &[ReprColumnType],
293    sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295    let mut diffs = Vec::new();
296
297    if sub.len() != sup.len() {
298        diffs.push(ReprRelationTypeDifference::Length {
299            len_sub: sub.len(),
300            len_sup: sup.len(),
301        });
302
303        // TODO(mgree) we could do an edit-distance computation to report more errors
304        return diffs;
305    }
306
307    diffs.extend(
308        sub.iter()
309            .zip_eq(sup.iter())
310            .enumerate()
311            .flat_map(|(col, (sub_ty, sup_ty))| {
312                column_subtype_difference(sub_ty, sup_ty)
313                    .into_iter()
314                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315            }),
316    );
317
318    diffs
319}
320
321/// Returns a list of differences that make `sub` not a subtype of `sup`
322///
323/// This function returns an empty list when `sub` is a subtype of `sup`
324pub fn column_subtype_difference(
325    sub: &ReprColumnType,
326    sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330    if sub.nullable && !sup.nullable {
331        diffs.push(ReprColumnTypeDifference::Nullability {
332            sub: sub.clone(),
333            sup: sup.clone(),
334        });
335    }
336
337    diffs
338}
339
340/// Returns a list of differences that make `sub` not a subtype of `sup`
341///
342/// This function returns an empty list when `sub` is a subtype of `sup`
343pub fn scalar_subtype_difference(
344    sub: &ReprScalarType,
345    sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347    use ReprScalarType::*;
348
349    let mut diffs = Vec::new();
350
351    match (sub, sup) {
352        (
353            List {
354                element_type: sub_elt,
355                ..
356            },
357            List {
358                element_type: sup_elt,
359                ..
360            },
361        )
362        | (
363            Map {
364                value_type: sub_elt,
365                ..
366            },
367            Map {
368                value_type: sup_elt,
369                ..
370            },
371        )
372        | (
373            Range {
374                element_type: sub_elt,
375                ..
376            },
377            Range {
378                element_type: sup_elt,
379                ..
380            },
381        )
382        | (Array(sub_elt), Array(sup_elt)) => {
383            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384            diffs.extend(
385                scalar_subtype_difference(sub_elt, sup_elt)
386                    .into_iter()
387                    .map(|diff| ReprColumnTypeDifference::ElementType {
388                        ctor: ctor.clone(),
389                        element_type: Box::new(diff),
390                    }),
391            );
392        }
393        (
394            Record {
395                fields: sub_fields, ..
396            },
397            Record {
398                fields: sup_fields, ..
399            },
400        ) => {
401            if sub_fields.len() != sup_fields.len() {
402                diffs.push(ReprColumnTypeDifference::NotSubtype {
403                    sub: sub.clone(),
404                    sup: sup.clone(),
405                });
406                return diffs;
407            }
408
409            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411            }
412        }
413        (_, _) => {
414            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415                diffs.push(ReprColumnTypeDifference::NotSubtype {
416                    sub: sub.clone(),
417                    sup: sup.clone(),
418                })
419            }
420        }
421    };
422
423    diffs
424}
425
426/// Unions `other` into `typ`, returning a list of differences on failure
427///
428/// This function returns an empty list when `typ` and `other` are a union
429pub fn scalar_union(
430    typ: &mut ReprScalarType,
431    other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433    use ReprScalarType::*;
434
435    let mut diffs = Vec::new();
436
437    // precomputing to appease the borrow checker
438    let ctor = ReprScalarBaseType::from(&*typ);
439    match (typ, other) {
440        (
441            List {
442                element_type: typ_elt,
443            },
444            List {
445                element_type: other_elt,
446            },
447        )
448        | (
449            Map {
450                value_type: typ_elt,
451            },
452            Map {
453                value_type: other_elt,
454            },
455        )
456        | (
457            Range {
458                element_type: typ_elt,
459            },
460            Range {
461                element_type: other_elt,
462            },
463        )
464        | (Array(typ_elt), Array(other_elt)) => {
465            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466            diffs.extend(
467                res.into_iter()
468                    .map(|diff| ReprColumnTypeDifference::ElementType {
469                        ctor: format!("{ctor:?}"),
470                        element_type: Box::new(diff),
471                    }),
472            );
473        }
474        (
475            Record { fields: typ_fields },
476            Record {
477                fields: other_fields,
478            },
479        ) => {
480            if typ_fields.len() != other_fields.len() {
481                diffs.push(ReprColumnTypeDifference::NotSubtype {
482                    sub: ReprScalarType::Record {
483                        fields: typ_fields.clone(),
484                    },
485                    sup: other.clone(),
486                });
487                return diffs;
488            }
489
490            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491                diffs.extend(column_union(typ_ty, other_ty));
492            }
493        }
494        (typ, _) => {
495            if ctor != ReprScalarBaseType::from(other) {
496                diffs.push(ReprColumnTypeDifference::NotSubtype {
497                    sub: typ.clone(),
498                    sup: other.clone(),
499                })
500            }
501        }
502    };
503
504    diffs
505}
506
507/// Unions `other` into `typ`, returning a list of differences on failure
508///
509/// This function returns an empty list when `typ` and `other` are a union
510pub fn column_union(
511    typ: &mut ReprColumnType,
512    other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516    if diffs.is_empty() {
517        typ.nullable |= other.nullable;
518    }
519
520    diffs
521}
522
523/// Returns true when it is safe to treat a `sub` row as an `sup` row
524///
525/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
526/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
527pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528    if sub.len() != sup.len() {
529        return false;
530    }
531
532    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534    })
535}
536
537/// Characterizes how a Datum differs from a ReprColumnType
538#[derive(Clone, Debug)]
539pub enum DatumTypeDifference {
540    /// Datum was null, but expected a non-null value (i.e., a ReprScalarType)
541    Null,
542    /// Datum's kind doesn't match the expected ReprScalarType
543    Mismatch {
544        /// The debug rendering of the Datum that we got (we don't store the Datum itself to avoid borrowing issues)
545        /// This debug rendering includes the DatumKind.
546        got_debug: String,
547        /// The representation scalar type that we expected
548        expected: ReprScalarType,
549    },
550    /// Datum's dimensions didn't match the type's dimensions
551    MismatchDimensions {
552        /// The type constructor (list, array, int2vector, etc.)
553        ctor: String,
554        /// The dimension of the datum that we got
555        got: usize,
556        /// The dimension of the type that we expected
557        expected: usize,
558    },
559    /// There was a nested difference in the element type of some container type
560    ElementType {
561        /// The type constructor (list, array, int2vector, etc.)
562        ctor: String,
563        /// The difference in the element type
564        element_type: Box<DatumTypeDifference>,
565    },
566}
567
568/// Computes the difference between this datum and the specified (representation) column type.
569///
570/// See [`Datum<'a>::is_instance_of`] for just getting a yes/no answer.
571///
572/// See [`Datum<'a>::is_instance_of_sql`] for comparing `Datum`s to `SqlColumnType`s.
573fn datum_difference_with_column_type(
574    datum: &Datum<'_>,
575    column_type: &ReprColumnType,
576) -> Result<(), DatumTypeDifference> {
577    fn difference_with_scalar_type(
578        datum: &Datum<'_>,
579        scalar_type: &ReprScalarType,
580    ) -> Result<(), DatumTypeDifference> {
581        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
582            Err(DatumTypeDifference::Mismatch {
583                // this will be redacted as appropriate
584                got_debug: format!("{got:?}"),
585                expected: expected.clone(),
586            })
587        }
588
589        if let ReprScalarType::Jsonb = scalar_type {
590            // json type checking
591            match datum {
592                Datum::Dummy => Ok(()), // allow dummys (unlike is_instance_of)
593                Datum::Null => Err(DatumTypeDifference::Null),
594                Datum::JsonNull
595                | Datum::False
596                | Datum::True
597                | Datum::Numeric(_)
598                | Datum::String(_) => Ok(()),
599                Datum::List(list) => {
600                    for elem in list.iter() {
601                        difference_with_scalar_type(&elem, scalar_type)?;
602                    }
603                    Ok(())
604                }
605                Datum::Map(dict) => {
606                    for (_, val) in dict.iter() {
607                        difference_with_scalar_type(&val, scalar_type)?;
608                    }
609                    Ok(())
610                }
611                _ => mismatch(datum, scalar_type),
612            }
613        } else {
614            fn element_type_difference(
615                ctor: &str,
616                element_type: DatumTypeDifference,
617            ) -> DatumTypeDifference {
618                DatumTypeDifference::ElementType {
619                    ctor: ctor.to_string(),
620                    element_type: Box::new(element_type),
621                }
622            }
623            match (datum, scalar_type) {
624                (Datum::Dummy, _) => Ok(()), // allow dummys (unlike is_instance_of)
625                (Datum::Null, _) => Err(DatumTypeDifference::Null),
626                (Datum::False, ReprScalarType::Bool) => Ok(()),
627                (Datum::False, _) => mismatch(datum, scalar_type),
628                (Datum::True, ReprScalarType::Bool) => Ok(()),
629                (Datum::True, _) => mismatch(datum, scalar_type),
630                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
631                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
632                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
633                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
634                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
635                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
636                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
637                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
638                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
639                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
640                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
641                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
642                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
643                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
644                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
645                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
646                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
647                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
648                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
649                (Datum::Date(_), _) => mismatch(datum, scalar_type),
650                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
651                (Datum::Time(_), _) => mismatch(datum, scalar_type),
652                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
653                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
654                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
655                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
656                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
657                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
658                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
659                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
660                (Datum::String(_), ReprScalarType::String) => Ok(()),
661                (Datum::String(_), _) => mismatch(datum, scalar_type),
662                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
663                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
664                (Datum::Array(array), ReprScalarType::Array(t)) => {
665                    for e in array.elements().iter() {
666                        if let Datum::Null = e {
667                            continue;
668                        }
669
670                        difference_with_scalar_type(&e, t)
671                            .map_err(|e| element_type_difference("array", e))?;
672                    }
673                    Ok(())
674                }
675                (Datum::Array(array), ReprScalarType::Int2Vector) => {
676                    if array.dims().len() != 1 {
677                        return Err(DatumTypeDifference::MismatchDimensions {
678                            ctor: "int2vector".to_string(),
679                            got: array.dims().len(),
680                            expected: 1,
681                        });
682                    }
683
684                    for e in array.elements().iter() {
685                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
686                            .map_err(|e| element_type_difference("int2vector", e))?;
687                    }
688
689                    Ok(())
690                }
691                (Datum::Array(_), _) => mismatch(datum, scalar_type),
692                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
693                    for e in list.iter() {
694                        if let Datum::Null = e {
695                            continue;
696                        }
697
698                        difference_with_scalar_type(&e, element_type)
699                            .map_err(|e| element_type_difference("list", e))?;
700                    }
701                    Ok(())
702                }
703                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
704                    let len = list.iter().count();
705                    if len != fields.len() {
706                        return Err(DatumTypeDifference::MismatchDimensions {
707                            ctor: "record".to_string(),
708                            got: len,
709                            expected: fields.len(),
710                        });
711                    }
712
713                    for (e, t) in list.iter().zip_eq(fields) {
714                        if let Datum::Null = e {
715                            continue;
716                        }
717
718                        difference_with_scalar_type(&e, &t.scalar_type)
719                            .map_err(|e| element_type_difference("record", e))?;
720                    }
721                    Ok(())
722                }
723                (Datum::List(_), _) => mismatch(datum, scalar_type),
724                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
725                    for (_, v) in map.iter() {
726                        if let Datum::Null = v {
727                            continue;
728                        }
729
730                        difference_with_scalar_type(&v, value_type)
731                            .map_err(|e| element_type_difference("map", e))?;
732                    }
733                    Ok(())
734                }
735                (Datum::Map(_), _) => mismatch(datum, scalar_type),
736                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
737                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
738                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
739                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
740                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
741                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
742                    match inner {
743                        None => Ok(()),
744                        Some(inner) => {
745                            if let Some(b) = inner.lower.bound {
746                                difference_with_scalar_type(&b.datum(), element_type)
747                                    .map_err(|e| element_type_difference("range", e))?;
748                            }
749                            if let Some(b) = inner.upper.bound {
750                                difference_with_scalar_type(&b.datum(), element_type)
751                                    .map_err(|e| element_type_difference("range", e))?;
752                            }
753                            Ok(())
754                        }
755                    }
756                }
757                (Datum::Range(_), _) => mismatch(datum, scalar_type),
758                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
759                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
760                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
761                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
762            }
763        }
764    }
765    if column_type.nullable {
766        if let Datum::Null = datum {
767            return Ok(());
768        }
769    }
770    difference_with_scalar_type(datum, &column_type.scalar_type)
771}
772
773fn row_difference_with_column_types<'a>(
774    source: &'a MirRelationExpr,
775    datums: &Vec<Datum<'_>>,
776    column_types: &[ReprColumnType],
777) -> Result<(), TypeError<'a>> {
778    // correct length
779    if datums.len() != column_types.len() {
780        return Err(TypeError::BadConstantRowLen {
781            source,
782            got: datums.len(),
783            expected: column_types.to_vec(),
784        });
785    }
786
787    // correct types
788    let mut mismatches = Vec::new();
789    for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
790        if let Err(e) = datum_difference_with_column_type(d, ty) {
791            mismatches.push((i, e));
792        }
793    }
794    if !mismatches.is_empty() {
795        return Err(TypeError::BadConstantRow {
796            source,
797            mismatches,
798            expected: column_types.to_vec(),
799        });
800    }
801
802    Ok(())
803}
/// Check that the visible type of each query has not been changed
///
/// Construct with [`Typecheck::new`]; the optional strictness checks are enabled with the
/// builder methods [`Typecheck::disallow_new_globals`], [`Typecheck::strict_join_equivalences`],
/// and [`Typecheck::disallow_dummy`].
#[derive(Debug)]
pub struct Typecheck {
    /// The known types of the queries so far (a context shared with other typechecking passes)
    ctx: SharedContext,
    /// Whether new non-transient global IDs are treated as an error
    ///
    /// Meant to be enabled only once the context has been populated, e.g. by an earlier run.
    disallow_new_globals: bool,
    /// Whether or not to be strict about join equivalences having the same nullability
    strict_join_equivalences: bool,
    /// Whether or not to disallow dummy values (reported as `TypeError::DisallowedDummy`)
    disallow_dummy: bool,
    /// Recursion guard for checked recursion (see the `CheckedRecursion` impl)
    recursion_guard: RecursionGuard,
}
818
/// Allows the typechecking methods to recurse through `checked_recur`, so that recursion depth
/// is bounded by `recursion_guard` and exceeding it surfaces as a `RecursionLimitError`-derived
/// error rather than unbounded recursion.
impl CheckedRecursion for Typecheck {
    /// Returns the guard bounding recursion depth (created with `RECURSION_LIMIT` in `new`).
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
824
825impl Typecheck {
826    /// Creates a typechecking consistency checking pass using a given shared context
827    pub fn new(ctx: SharedContext) -> Self {
828        Self {
829            ctx,
830            disallow_new_globals: false,
831            strict_join_equivalences: false,
832            disallow_dummy: false,
833            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
834        }
835    }
836
837    /// New non-transient global IDs will be treated as an error
838    ///
839    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
840    pub fn disallow_new_globals(mut self) -> Self {
841        self.disallow_new_globals = true;
842        self
843    }
844
845    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
846    ///
847    /// Only turn this on before `JoinImplementation`
848    pub fn strict_join_equivalences(mut self) -> Self {
849        self.strict_join_equivalences = true;
850
851        self
852    }
853
854    /// Disallow dummy values
855    pub fn disallow_dummy(mut self) -> Self {
856        self.disallow_dummy = true;
857        self
858    }
859
860    /// Returns the type of a relation expression or a type error.
861    ///
862    /// This function is careful to check validity, not just find out the type.
863    ///
864    /// It should be linear in the size of the AST.
865    ///
866    /// ??? should we also compute keys and return a `ReprRelationType`?
867    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
868    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
869    ///   reliably keep it stable across transformations.
870    pub fn typecheck<'a>(
871        &self,
872        expr: &'a MirRelationExpr,
873        ctx: &Context,
874    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
875        use MirRelationExpr::*;
876
877        self.checked_recur(|tc| match expr {
878            Constant { typ, rows } => {
879                if let Ok(rows) = rows {
880                    for (row, _id) in rows {
881                        let datums = row.unpack();
882
883                        row_difference_with_column_types(expr, &datums, &typ.column_types.iter().map(ReprColumnType::from).collect_vec())?;
884
885                        if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
886                            return Err(TypeError::DisallowedDummy {
887                                source: expr,
888                            });
889                        }
890                    }
891                }
892
893                Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
894            }
895            Get { typ, id, .. } => {
896                if let Id::Global(_global_id) = id {
897                    if !ctx.contains_key(id) {
898                        // TODO(mgree) pass QueryContext through to check these types
899                        return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
900                    }
901                }
902
903                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
904                    source: expr,
905                    id: id.clone(),
906                    typ: ReprRelationType::from(typ),
907                })?;
908
909                let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
910
911                // covariant: the ascribed type must be a subtype of the actual type in the context
912                let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
913
914                if !diffs.is_empty() {
915                    return Err(TypeError::MismatchColumns {
916                        source: expr,
917                        got: column_types,
918                        expected: ctx_typ.clone(),
919                        diffs,
920                        message: "annotation did not match context type".to_string(),
921                    });
922                }
923
924                Ok(column_types)
925            }
926            Project { input, outputs } => {
927                let t_in = tc.typecheck(input, ctx)?;
928
929                for x in outputs {
930                    if *x >= t_in.len() {
931                        return Err(TypeError::BadProject {
932                            source: expr,
933                            got: outputs.clone(),
934                            input_type: t_in,
935                        });
936                    }
937                }
938
939                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
940            }
941            Map { input, scalars } => {
942                let mut t_in = tc.typecheck(input, ctx)?;
943
944                for scalar_expr in scalars.iter() {
945                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
946
947                    if self.disallow_dummy && scalar_expr.contains_dummy() {
948                        return Err(TypeError::DisallowedDummy {
949                            source: expr,
950                        });
951                    }
952                }
953
954                Ok(t_in)
955            }
956            FlatMap { input, func, exprs } => {
957                let mut t_in = tc.typecheck(input, ctx)?;
958
959                let mut t_exprs = Vec::with_capacity(exprs.len());
960                for scalar_expr in exprs {
961                    t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
962
963                    if self.disallow_dummy && scalar_expr.contains_dummy() {
964                        return Err(TypeError::DisallowedDummy {
965                            source: expr,
966                        });
967                    }
968                }
969                // TODO(mgree) check t_exprs agrees with `func`'s input type
970
971                let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
972
973                // FlatMap extends the existing columns
974                t_in.extend(t_out);
975                Ok(t_in)
976            }
977            Filter { input, predicates } => {
978                let mut t_in = tc.typecheck(input, ctx)?;
979
980                // Set as nonnull any columns where null values would cause
981                // any predicate to evaluate to null.
982                for column in non_nullable_columns(predicates) {
983                    t_in[column].nullable = false;
984                }
985
986                for scalar_expr in predicates {
987                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
988
989                    // filter condition must be boolean
990                    // ignoring nullability: null is treated as false
991                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
992                    if t.scalar_type != ReprScalarType::Bool {
993                        let sub = t.scalar_type.clone();
994
995                        return Err(TypeError::MismatchColumn {
996                            source: expr,
997                            got: t,
998                            expected: ReprColumnType {
999                                scalar_type: ReprScalarType::Bool,
1000                                nullable: true,
1001                            },
1002                            diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
1003                            message: "expected boolean condition".to_string(),
1004                        });
1005                    }
1006
1007                    if self.disallow_dummy && scalar_expr.contains_dummy() {
1008                        return Err(TypeError::DisallowedDummy {
1009                            source: expr,
1010                        });
1011                    }
1012                }
1013
1014                Ok(t_in)
1015            }
1016            Join {
1017                inputs,
1018                equivalences,
1019                implementation,
1020            } => {
1021                let mut t_in_global = Vec::new();
1022                let mut t_in_local = vec![Vec::new(); inputs.len()];
1023
1024                for (i, input) in inputs.iter().enumerate() {
1025                    let input_t = tc.typecheck(input, ctx)?;
1026                    t_in_global.extend(input_t.clone());
1027                    t_in_local[i] = input_t;
1028                }
1029
1030                for eq_class in equivalences {
1031                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1032
1033                    let mut all_nullable = true;
1034
1035                    for scalar_expr in eq_class {
1036                        // Note: the equivalences have global column references
1037                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1038
1039                        if !t_expr.nullable {
1040                            all_nullable = false;
1041                        }
1042
1043                        if let Some(t_first) = t_exprs.get(0) {
1044                            let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
1045                            if !diffs.is_empty() {
1046                                return Err(TypeError::MismatchColumn {
1047                                    source: expr,
1048                                    got: t_expr,
1049                                    expected: t_first.clone(),
1050                                    diffs,
1051                                    message: "equivalence class members have different scalar types".to_string(),
1052                                });
1053                            }
1054
1055                            // equivalences may or may not match on nullability
1056                            // before JoinImplementation runs, nullability should match.
1057                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
1058                            if self.strict_join_equivalences {
1059                                if t_expr.nullable != t_first.nullable {
1060                                    let sub = t_expr.clone();
1061                                    let sup = t_first.clone();
1062
1063                                    let err = TypeError::MismatchColumn {
1064                                        source: expr,
1065                                        got: t_expr.clone(),
1066                                        expected: t_first.clone(),
1067                                        diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
1068                                        message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
1069                                    };
1070
1071                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1072                                    ::tracing::debug!("{err}");
1073                                }
1074                            }
1075                        }
1076
1077                        if self.disallow_dummy && scalar_expr.contains_dummy() {
1078                            return Err(TypeError::DisallowedDummy {
1079                                source: expr,
1080                            });
1081                        }
1082
1083                        t_exprs.push(t_expr);
1084                    }
1085
1086                    if self.strict_join_equivalences && all_nullable {
1087                        let err = TypeError::BadJoinEquivalence {
1088                            source: expr,
1089                            got: t_exprs,
1090                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1091                        };
1092
1093                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1094                        ::tracing::debug!("{err}");
1095                    }
1096                }
1097
1098                // check that the join implementation is consistent
1099                match implementation {
1100                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
1101                        if let Some(key) = first_key {
1102                            for k in key {
1103                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1104                            }
1105                        }
1106
1107                        for (idx, key, _) in others {
1108                            for k in key {
1109                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1110                            }
1111                        }
1112                    }
1113                    JoinImplementation::DeltaQuery(plans) => {
1114                        for plan in plans {
1115                            for (idx, key, _) in plan {
1116                                for k in key {
1117                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1118                                }
1119                            }
1120                        }
1121                    }
1122                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1123                        let typ: Vec<ReprColumnType> = key
1124                            .iter()
1125                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1126                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1127
1128                        for row in consts {
1129                            let datums = row.unpack();
1130
1131                            row_difference_with_column_types(expr, &datums, &typ)?;
1132                        }
1133                    }
1134                    JoinImplementation::Unimplemented => (),
1135                }
1136
1137                Ok(t_in_global)
1138            }
1139            Reduce {
1140                input,
1141                group_key,
1142                aggregates,
1143                monotonic: _,
1144                expected_group_size: _,
1145            } => {
1146                let t_in = tc.typecheck(input, ctx)?;
1147
1148                let mut t_out = group_key
1149                    .iter()
1150                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1151                    .collect::<Result<Vec<_>, _>>()?;
1152
1153                    if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
1154                        return Err(TypeError::DisallowedDummy {
1155                            source: expr,
1156                        });
1157                    }
1158
1159                for agg in aggregates {
1160                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1161                }
1162
1163                Ok(t_out)
1164            }
1165            TopK {
1166                input,
1167                group_key,
1168                order_key,
1169                limit: _,
1170                offset: _,
1171                monotonic: _,
1172                expected_group_size: _,
1173            } => {
1174                let t_in = tc.typecheck(input, ctx)?;
1175
1176                for &k in group_key {
1177                    if k >= t_in.len() {
1178                        return Err(TypeError::BadTopKGroupKey {
1179                            source: expr,
1180                            k,
1181                            input_type: t_in,
1182                        });
1183                    }
1184                }
1185
1186                for order in order_key {
1187                    if order.column >= t_in.len() {
1188                        return Err(TypeError::BadTopKOrdering {
1189                            source: expr,
1190                            order: order.clone(),
1191                            input_type: t_in,
1192                        });
1193                    }
1194                }
1195
1196                Ok(t_in)
1197            }
1198            Negate { input } => tc.typecheck(input, ctx),
1199            Threshold { input } => tc.typecheck(input, ctx),
1200            Union { base, inputs } => {
1201                let mut t_base = tc.typecheck(base, ctx)?;
1202
1203                for input in inputs {
1204                    let t_input = tc.typecheck(input, ctx)?;
1205
1206                    let len_sub = t_base.len();
1207                    let len_sup = t_input.len();
1208                    if len_sub != len_sup {
1209                        return Err(TypeError::MismatchColumns {
1210                            source: expr,
1211                            got: t_base.clone(),
1212                            expected: t_input,
1213                            diffs: vec![ReprRelationTypeDifference::Length {
1214                                len_sub,
1215                                len_sup,
1216                            }],
1217                            message: "Union branches have different numbers of columns".to_string(),
1218                        });
1219                    }
1220
1221                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1222                        let diffs = column_union(base_col, &input_col);
1223                        if !diffs.is_empty() {
1224                            return Err(TypeError::MismatchColumn {
1225                                    source: expr,
1226                                    got: input_col,
1227                                    expected: base_col.clone(),
1228                                    diffs,
1229                                    message:
1230                                        "couldn't compute union of column types in Union"
1231                                    .to_string(),
1232                            });
1233                        }
1234
1235                    }
1236                }
1237
1238                Ok(t_base)
1239            }
1240            Let { id, value, body } => {
1241                let t_value = tc.typecheck(value, ctx)?;
1242
1243                let binding = Id::Local(*id);
1244                if ctx.contains_key(&binding) {
1245                    return Err(TypeError::Shadowing {
1246                        source: expr,
1247                        id: binding,
1248                    });
1249                }
1250
1251                let mut body_ctx = ctx.clone();
1252                body_ctx.insert(Id::Local(*id), t_value);
1253
1254                tc.typecheck(body, &body_ctx)
1255            }
1256            LetRec { ids, values, body, limits: _ } => {
1257                if ids.len() != values.len() {
1258                    return Err(TypeError::BadLetRecBindings { source: expr });
1259                }
1260
1261                // temporary hack: steal info from the Gets inside to learn the expected types
1262                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1263                let mut ctx = ctx.clone();
1264                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1265                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1266                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1267                }
1268
1269                for (id, value) in ids.iter().zip_eq(values.iter()) {
1270                    let typ = tc.typecheck(value, &ctx)?;
1271
1272                    let id = Id::Local(id.clone());
1273                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1274                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1275                            // we expect an EXACT match, but don't care about nullability
1276                            let diffs = column_union(base_col, &input_col);
1277                            if !diffs.is_empty() {
1278                                 return Err(TypeError::MismatchColumn {
1279                                        source: expr,
1280                                        got: input_col,
1281                                        expected: base_col.clone(),
1282                                        diffs,
1283                                        message:
1284                                            "couldn't compute union of column types in LetRec"
1285                                        .to_string(),
1286                                    })
1287                            }
1288                        }
1289                    } else {
1290                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1291                        ctx.insert(id, typ);
1292                    }
1293                }
1294
1295                tc.typecheck(body, &ctx)
1296            }
1297            ArrangeBy { input, keys } => {
1298                let t_in = tc.typecheck(input, ctx)?;
1299
1300                for key in keys {
1301                    for k in key {
1302                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1303                    }
1304                }
1305
1306                Ok(t_in)
1307            }
1308        })
1309    }
1310
1311    /// Traverses a term to collect the types of given ids.
1312    ///
1313    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1314    fn collect_recursive_variable_types<'a>(
1315        &self,
1316        expr: &'a MirRelationExpr,
1317        ids: &[LocalId],
1318        ctx: &mut Context,
1319    ) -> Result<(), TypeError<'a>> {
1320        use MirRelationExpr::*;
1321
1322        self.checked_recur(|tc| {
1323            match expr {
1324                Get {
1325                    id: Id::Local(id),
1326                    typ,
1327                    ..
1328                } => {
1329                    if !ids.contains(id) {
1330                        return Ok(());
1331                    }
1332
1333                    let id = Id::Local(id.clone());
1334                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1335                        let typ = typ
1336                            .column_types
1337                            .iter()
1338                            .map(ReprColumnType::from)
1339                            .collect_vec();
1340
1341                        if ctx_typ.len() != typ.len() {
1342                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1343
1344                            return Err(TypeError::MismatchColumns {
1345                                source: expr,
1346                                got: typ,
1347                                expected: ctx_typ.clone(),
1348                                diffs,
1349                                message: "environment and type annotation did not match"
1350                                    .to_string(),
1351                            });
1352                        }
1353
1354                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1355                            let diffs = column_union(base_col, &input_col);
1356                            if !diffs.is_empty() {
1357                                return Err(TypeError::MismatchColumn {
1358                                    source: expr,
1359                                    got: input_col,
1360                                    expected: base_col.clone(),
1361                                    diffs,
1362                                    message:
1363                                        "couldn't compute union of column types in Get and context"
1364                                            .to_string(),
1365                                });
1366                            }
1367                        }
1368                    } else {
1369                        ctx.insert(
1370                            id,
1371                            typ.column_types
1372                                .iter()
1373                                .map(ReprColumnType::from)
1374                                .collect_vec(),
1375                        );
1376                    }
1377                }
1378                Get {
1379                    id: Id::Global(..), ..
1380                }
1381                | Constant { .. } => (),
1382                Let { id, value, body } => {
1383                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1384
1385                    // we've shadowed the id
1386                    if ids.contains(id) {
1387                        return Err(TypeError::Shadowing {
1388                            source: expr,
1389                            id: Id::Local(*id),
1390                        });
1391                    }
1392
1393                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1394                }
1395                LetRec {
1396                    ids: inner_ids,
1397                    values,
1398                    body,
1399                    limits: _,
1400                } => {
1401                    for inner_id in inner_ids {
1402                        if ids.contains(inner_id) {
1403                            return Err(TypeError::Shadowing {
1404                                source: expr,
1405                                id: Id::Local(*inner_id),
1406                            });
1407                        }
1408                    }
1409
1410                    for value in values {
1411                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1412                    }
1413
1414                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1415                }
1416                Project { input, .. }
1417                | Map { input, .. }
1418                | FlatMap { input, .. }
1419                | Filter { input, .. }
1420                | Reduce { input, .. }
1421                | TopK { input, .. }
1422                | Negate { input }
1423                | Threshold { input }
1424                | ArrangeBy { input, .. } => {
1425                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1426                }
1427                Join { inputs, .. } => {
1428                    for input in inputs {
1429                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1430                    }
1431                }
1432                Union { base, inputs } => {
1433                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1434
1435                    for input in inputs {
1436                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1437                    }
1438                }
1439            }
1440
1441            Ok(())
1442        })
1443    }
1444
1445    fn typecheck_scalar<'a>(
1446        &self,
1447        expr: &'a MirScalarExpr,
1448        source: &'a MirRelationExpr,
1449        column_types: &[ReprColumnType],
1450    ) -> Result<ReprColumnType, TypeError<'a>> {
1451        use MirScalarExpr::*;
1452
1453        self.checked_recur(|tc| match expr {
1454            Column(i, _) => match column_types.get(*i) {
1455                Some(ty) => Ok(ty.clone()),
1456                None => Err(TypeError::NoSuchColumn {
1457                    source,
1458                    expr,
1459                    col: *i,
1460                }),
1461            },
1462            Literal(row, typ) => {
1463                let typ = ReprColumnType::from(typ);
1464                if let Ok(row) = row {
1465                    let datums = row.unpack();
1466
1467                    row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1468                }
1469
1470                Ok(typ)
1471            }
1472            CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1473            CallUnary { expr, func } => {
1474                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1475                let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1476                Ok(ReprColumnType::from(&typ_out))
1477            }
1478            CallBinary { expr1, expr2, func } => {
1479                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1480                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1481                let typ_out = func.output_type(
1482                    SqlColumnType::from_repr(&typ_in1),
1483                    SqlColumnType::from_repr(&typ_in2),
1484                );
1485                Ok(ReprColumnType::from(&typ_out))
1486            }
1487            CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1488                &func.output_type(
1489                    exprs
1490                        .iter()
1491                        .map(|e| {
1492                            tc.typecheck_scalar(e, source, column_types)
1493                                .map(|typ| SqlColumnType::from_repr(&typ))
1494                        })
1495                        .collect::<Result<Vec<_>, TypeError>>()?,
1496                ),
1497            )),
1498            If { cond, then, els } => {
1499                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1500
1501                // condition must be boolean
1502                // ignoring nullability: null is treated as false
1503                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1504                if cond_type.scalar_type != ReprScalarType::Bool {
1505                    let sub = cond_type.scalar_type.clone();
1506
1507                    return Err(TypeError::MismatchColumn {
1508                        source,
1509                        got: cond_type,
1510                        expected: ReprColumnType {
1511                            scalar_type: ReprScalarType::Bool,
1512                            nullable: true,
1513                        },
1514                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1515                            sub,
1516                            sup: ReprScalarType::Bool,
1517                        }],
1518                        message: "expected boolean condition".to_string(),
1519                    });
1520                }
1521
1522                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1523                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1524
1525                let diffs = column_union(&mut then_type, &else_type);
1526                if !diffs.is_empty() {
1527                    return Err(TypeError::MismatchColumn {
1528                        source,
1529                        got: then_type,
1530                        expected: else_type,
1531                        diffs,
1532                        message: "couldn't compute union of column types for If".to_string(),
1533                    });
1534                }
1535
1536                Ok(then_type)
1537            }
1538        })
1539    }
1540
1541    /// Typecheck an `AggregateExpr`
1542    pub fn typecheck_aggregate<'a>(
1543        &self,
1544        expr: &'a AggregateExpr,
1545        source: &'a MirRelationExpr,
1546        column_types: &[ReprColumnType],
1547    ) -> Result<ReprColumnType, TypeError<'a>> {
1548        self.checked_recur(|tc| {
1549            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1550
1551            // TODO check that t_in is actually acceptable for `func`
1552
1553            Ok(ReprColumnType::from(
1554                &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1555            ))
1556        })
1557    }
1558}
1559
/// Reports a type problem discovered by the typechecker.
///
/// Usage: `type_error!(severity, <format string>, <args>...)`, where everything after
/// `severity` is passed through unchanged as formatting arguments.
///
/// - If `severity` is `true`, the message goes to `soft_panic_or_log!`. That is a
///   failure in CI and a logged error (visible in Sentry) in production.
/// - Otherwise, the message is only logged at `tracing`'s debug level.
///
/// Exactly one of the two happens. A severe report is not also logged at debug level.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        if $severity {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1572
1573impl crate::Transform for Typecheck {
1574    fn name(&self) -> &'static str {
1575        "Typecheck"
1576    }
1577
1578    fn actually_perform_transform(
1579        &self,
1580        relation: &mut MirRelationExpr,
1581        transform_ctx: &mut crate::TransformCtx,
1582    ) -> Result<(), crate::TransformError> {
1583        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1584
1585        let expected = transform_ctx
1586            .global_id
1587            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1588
1589        if let Some(id) = transform_ctx.global_id {
1590            if self.disallow_new_globals
1591                && expected.is_none()
1592                && transform_ctx.global_id.is_some()
1593                && !id.is_transient()
1594            {
1595                type_error!(
1596                    false, // not severe
1597                    "type warning: new non-transient global id {id}\n{}",
1598                    relation.pretty()
1599                );
1600            }
1601        }
1602
1603        let got = self.typecheck(relation, &typecheck_ctx);
1604
1605        let humanizer = mz_repr::explain::DummyHumanizer;
1606
1607        match (got, expected) {
1608            (Ok(got), Some(expected)) => {
1609                let id = transform_ctx.global_id.unwrap();
1610
1611                // contravariant: global types can be updated
1612                let diffs = relation_subtype_difference(expected, &got);
1613                if !diffs.is_empty() {
1614                    // SEVERE only if got and expected have true differences, not just nullability
1615                    let severity = diffs
1616                        .iter()
1617                        .any(|diff| diff.clone().ignore_nullability().is_some());
1618
1619                    let err = TypeError::MismatchColumns {
1620                        source: relation,
1621                        got,
1622                        expected: expected.clone(),
1623                        diffs,
1624                        message: format!(
1625                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1626                        ),
1627                    };
1628
1629                    type_error!(severity, "type error in known global id {id}:\n{err}");
1630                }
1631            }
1632            (Ok(got), None) => {
1633                if let Some(id) = transform_ctx.global_id {
1634                    typecheck_ctx.insert(Id::Global(id), got);
1635                }
1636            }
1637            (Err(err), _) => {
1638                let (expected, binding) = match expected {
1639                    Some(expected) => {
1640                        let id = transform_ctx.global_id.unwrap();
1641                        (
1642                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1643                            format!("known global id {id}"),
1644                        )
1645                    }
1646                    None => ("".to_string(), "transient query".to_string()),
1647                };
1648
1649                type_error!(
1650                    true, // SEVERE: the transformed code is inconsistent
1651                    "type error in {binding}:\n{err}\n{expected}{}",
1652                    relation.pretty()
1653                );
1654            }
1655        }
1656
1657        Ok(())
1658    }
1659}
1660
1661/// Prints a type prettily with a given `ExprHumanizer`
1662pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1663where
1664    H: ExprHumanizer,
1665{
1666    let mut s = String::with_capacity(2 + 3 * cols.len());
1667
1668    s.push('(');
1669
1670    let mut it = cols.iter().peekable();
1671    while let Some(col) = it.next() {
1672        s.push_str(&humanizer.humanize_column_type_repr(col, false));
1673
1674        if it.peek().is_some() {
1675            s.push_str(", ");
1676        }
1677    }
1678
1679    s.push(')');
1680
1681    s
1682}
1683
1684impl ReprRelationTypeDifference {
1685    /// Pretty prints a type difference
1686    ///
1687    /// Always indents two spaces
1688    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1689    where
1690        H: ExprHumanizer,
1691    {
1692        use ReprRelationTypeDifference::*;
1693        match self {
1694            Length { len_sub, len_sup } => {
1695                writeln!(
1696                    f,
1697                    "  number of columns do not match ({len_sub} != {len_sup})"
1698                )
1699            }
1700            Column { col, diff } => {
1701                writeln!(f, "  column {col} differs:")?;
1702                diff.humanize(4, h, f)
1703            }
1704        }
1705    }
1706}
1707
1708impl ReprColumnTypeDifference {
1709    /// Pretty prints a type difference at a given indentation level
1710    pub fn humanize<H>(
1711        &self,
1712        indent: usize,
1713        h: &H,
1714        f: &mut std::fmt::Formatter<'_>,
1715    ) -> std::fmt::Result
1716    where
1717        H: ExprHumanizer,
1718    {
1719        use ReprColumnTypeDifference::*;
1720
1721        // indent
1722        write!(f, "{:indent$}", "")?;
1723
1724        match self {
1725            NotSubtype { sub, sup } => {
1726                let sub = h.humanize_scalar_type_repr(sub, false);
1727                let sup = h.humanize_scalar_type_repr(sup, false);
1728
1729                writeln!(f, "{sub} is a not a subtype of {sup}")
1730            }
1731            Nullability { sub, sup } => {
1732                let sub = h.humanize_column_type_repr(sub, false);
1733                let sup = h.humanize_column_type_repr(sup, false);
1734
1735                writeln!(f, "{sub} is nullable but {sup} is not")
1736            }
1737            ElementType { ctor, element_type } => {
1738                writeln!(f, "{ctor} element types differ:")?;
1739
1740                element_type.humanize(indent + 2, h, f)
1741            }
1742            RecordMissingFields { missing } => {
1743                write!(f, "missing column fields:")?;
1744                for col in missing {
1745                    write!(f, " {col}")?;
1746                }
1747                f.write_char('\n')
1748            }
1749            RecordFields { fields } => {
1750                writeln!(f, "{} record fields differ:", fields.len())?;
1751
1752                for (i, diff) in fields.iter().enumerate() {
1753                    writeln!(f, "{:indent$}  field {i}:", "")?;
1754                    diff.humanize(indent + 4, h, f)?;
1755                }
1756                Ok(())
1757            }
1758        }
1759    }
1760}
1761
1762impl DatumTypeDifference {
1763    /// Pretty prints a type difference at a given indentation level
1764    pub fn humanize<H>(
1765        &self,
1766        indent: usize,
1767        h: &H,
1768        f: &mut std::fmt::Formatter<'_>,
1769    ) -> std::fmt::Result
1770    where
1771        H: ExprHumanizer,
1772    {
1773        // indent
1774        write!(f, "{:indent$}", "")?;
1775
1776        match self {
1777            DatumTypeDifference::Null => writeln!(f, "unexpected null")?,
1778            DatumTypeDifference::Mismatch {
1779                got_debug,
1780                expected,
1781            } => {
1782                let expected = h.humanize_scalar_type_repr(expected, false);
1783                // NB `got_debug` will be redacted as appropriate
1784                writeln!(
1785                    f,
1786                    "got datum {got_debug}, expected representation type {expected}"
1787                )?;
1788            }
1789            DatumTypeDifference::MismatchDimensions {
1790                ctor,
1791                got,
1792                expected,
1793            } => {
1794                writeln!(
1795                    f,
1796                    "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1797                )?;
1798            }
1799            DatumTypeDifference::ElementType { ctor, element_type } => {
1800                writeln!(f, "{ctor} element types differ:")?;
1801                element_type.humanize(indent + 4, h, f)?;
1802            }
1803        }
1804
1805        Ok(())
1806    }
1807}
1808
1809/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
1810#[allow(missing_debug_implementations)]
1811pub struct TypeErrorHumanizer<'a, 'b, H>
1812where
1813    H: ExprHumanizer,
1814{
1815    err: &'a TypeError<'a>,
1816    humanizer: &'b H,
1817}
1818
1819impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1820where
1821    H: ExprHumanizer,
1822{
1823    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1824    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1825        Self { err, humanizer }
1826    }
1827}
1828
1829impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1830where
1831    H: ExprHumanizer,
1832{
1833    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1834        self.err.humanize(self.humanizer, f)
1835    }
1836}
1837
1838impl<'a> std::fmt::Display for TypeError<'a> {
1839    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1840        TypeErrorHumanizer {
1841            err: self,
1842            humanizer: &DummyHumanizer,
1843        }
1844        .fmt(f)
1845    }
1846}
1847
1848impl<'a> TypeError<'a> {
1849    /// The source of the type error
1850    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1851        use TypeError::*;
1852        match self {
1853            Unbound { source, .. }
1854            | NoSuchColumn { source, .. }
1855            | MismatchColumn { source, .. }
1856            | MismatchColumns { source, .. }
1857            | BadConstantRowLen { source, .. }
1858            | BadConstantRow { source, .. }
1859            | BadProject { source, .. }
1860            | BadJoinEquivalence { source, .. }
1861            | BadTopKGroupKey { source, .. }
1862            | BadTopKOrdering { source, .. }
1863            | BadLetRecBindings { source }
1864            | Shadowing { source, .. }
1865            | DisallowedDummy { source, .. } => Some(source),
1866            Recursion { .. } => None,
1867        }
1868    }
1869
1870    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1871    where
1872        H: ExprHumanizer,
1873    {
1874        if let Some(source) = self.source() {
1875            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1876        }
1877
1878        use TypeError::*;
1879        match self {
1880            Unbound { source: _, id, typ } => {
1881                let typ = columns_pretty(&typ.column_types, humanizer);
1882                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1883            }
1884            NoSuchColumn {
1885                source: _,
1886                expr,
1887                col,
1888            } => writeln!(f, "{expr} references non-existent column {col}")?,
1889            MismatchColumn {
1890                source: _,
1891                got,
1892                expected,
1893                diffs,
1894                message,
1895            } => {
1896                let got = humanizer.humanize_column_type_repr(got, false);
1897                let expected = humanizer.humanize_column_type_repr(expected, false);
1898                writeln!(
1899                    f,
1900                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1901                )?;
1902
1903                for diff in diffs {
1904                    diff.humanize(2, humanizer, f)?;
1905                }
1906            }
1907            MismatchColumns {
1908                source: _,
1909                got,
1910                expected,
1911                diffs,
1912                message,
1913            } => {
1914                let got = columns_pretty(got, humanizer);
1915                let expected = columns_pretty(expected, humanizer);
1916
1917                writeln!(
1918                    f,
1919                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1920                )?;
1921
1922                for diff in diffs {
1923                    diff.humanize(humanizer, f)?;
1924                }
1925            }
1926            BadConstantRowLen {
1927                source: _,
1928                got,
1929                expected,
1930            } => {
1931                let expected = columns_pretty(expected, humanizer);
1932                writeln!(
1933                    f,
1934                    "bad constant row\n      row has length {got}\nexpected row of type {expected}"
1935                )?
1936            }
1937            BadConstantRow {
1938                source: _,
1939                mismatches,
1940                expected,
1941            } => {
1942                let expected = columns_pretty(expected, humanizer);
1943
1944                let num_mismatches = mismatches.len();
1945                let plural = if num_mismatches == 1 { "" } else { "es" };
1946                writeln!(
1947                    f,
1948                    "bad constant row\n      got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1949                )?;
1950
1951                if num_mismatches > 0 {
1952                    writeln!(f, "")?;
1953                    for (col, diff) in mismatches.iter() {
1954                        writeln!(f, "      column #{col}:")?;
1955                        diff.humanize(8, humanizer, f)?;
1956                    }
1957                }
1958            }
1959            BadProject {
1960                source: _,
1961                got,
1962                input_type,
1963            } => {
1964                let input_type = columns_pretty(input_type, humanizer);
1965
1966                writeln!(
1967                    f,
1968                    "projection of non-existant columns {got:?} from type {input_type}"
1969                )?
1970            }
1971            BadJoinEquivalence {
1972                source: _,
1973                got,
1974                message,
1975            } => {
1976                let got = columns_pretty(got, humanizer);
1977
1978                writeln!(f, "bad join equivalence {got}: {message}")?
1979            }
1980            BadTopKGroupKey {
1981                source: _,
1982                k,
1983                input_type,
1984            } => {
1985                let input_type = columns_pretty(input_type, humanizer);
1986
1987                writeln!(
1988                    f,
1989                    "TopK group key component references invalid column {k} in columns: {input_type}"
1990                )?
1991            }
1992            BadTopKOrdering {
1993                source: _,
1994                order,
1995                input_type,
1996            } => {
1997                let col = order.column;
1998                let num_cols = input_type.len();
1999                let are = if num_cols == 1 { "is" } else { "are" };
2000                let s = if num_cols == 1 { "" } else { "s" };
2001                let input_type = columns_pretty(input_type, humanizer);
2002
2003                // TODO(cloud#8196)
2004                let mode = HumanizedExplain::new(false);
2005                let order = mode.expr(order, None);
2006
2007                writeln!(
2008                    f,
2009                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2010                )?
2011            }
2012            BadLetRecBindings { source: _ } => {
2013                writeln!(f, "LetRec ids and definitions don't line up")?
2014            }
2015            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2016            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2017            Recursion { error } => writeln!(f, "{error}")?,
2018        }
2019
2020        Ok(())
2021    }
2022}
2023
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    /// A datum matching the column's scalar type is accepted; a datum of a different
    /// scalar type is rejected.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        // Data are generated to fit the column type, but generation may produce dummy
        // values, for which behavior is different. Those cases are rejected, so we allow
        // many more global rejects than the default.
        #![proptest_config(ProptestConfig { cases: 5000, max_global_rejects: 2500, ..Default::default() })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data((src, datum) in any::<SqlColumnType>().prop_flat_map(|src| {
            let datum = arb_datum_for_column(src.clone());
            (Just(src), datum) }
        )) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        // We run many cases because the data are _random_, and we want to be sure
        // that we have covered sufficient cases. Datum and type are generated
        // independently, so most cases exercise the mismatch path. No dummies are
        // expected here (asserted below), so no extra global rejects are needed.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(src in any::<SqlColumnType>(), datum in arb_datum(false)) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }
}