Skip to main content

mz_transform/
reprtypecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27    ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28    SqlColumnType,
29};
30
31/// Typechecking contexts as shared by various typechecking passes.
32///
33/// We use a `RefCell` to ensure that contexts are shared by multiple typechecker passes.
34/// Shared contexts help catch consistency issues.
35pub type SharedContext = Arc<Mutex<Context>>;
36
37/// Generates an empty context
38pub fn empty_context() -> SharedContext {
39    Arc::new(Mutex::new(BTreeMap::new()))
40}
41
42/// The possible forms of inconsistency/errors discovered during typechecking.
43///
44/// Every variant has a `source` field identifying the MIR term that is home
45/// to the error (though not necessarily the root cause of the error).
46#[derive(Clone, Debug)]
47pub enum TypeError<'a> {
48    /// Unbound identifiers (local or global)
49    Unbound {
50        /// Expression with the bug
51        source: &'a MirRelationExpr,
52        /// The (unbound) identifier referenced
53        id: Id,
54        /// The type `id` was expected to have
55        typ: ReprRelationType,
56    },
57    /// Dereference of a non-existent column
58    NoSuchColumn {
59        /// Expression with the bug
60        source: &'a MirRelationExpr,
61        /// Scalar expression that references an invalid column
62        expr: &'a MirScalarExpr,
63        /// The invalid column referenced
64        col: usize,
65    },
66    /// A single column type does not match
67    MismatchColumn {
68        /// Expression with the bug
69        source: &'a MirRelationExpr,
70        /// The column type we found (`sub` type)
71        got: ReprColumnType,
72        /// The column type we expected (`sup` type)
73        expected: ReprColumnType,
74        /// The difference between these types
75        diffs: Vec<ReprColumnTypeDifference>,
76        /// An explanatory message
77        message: String,
78    },
79    /// Relation column types do not match
80    MismatchColumns {
81        /// Expression with the bug
82        source: &'a MirRelationExpr,
83        /// The column types we found (`sub` type)
84        got: Vec<ReprColumnType>,
85        /// The column types we expected (`sup` type)
86        expected: Vec<ReprColumnType>,
87        /// The difference between these types
88        diffs: Vec<ReprRelationTypeDifference>,
89        /// An explanatory message
90        message: String,
91    },
92    /// A constant row does not have the correct type
93    BadConstantRowLen {
94        /// Expression with the bug
95        source: &'a MirRelationExpr,
96        /// A constant row
97        got: usize,
98        /// The expected type (which that row does not have)
99        expected: Vec<ReprColumnType>,
100    },
101    /// A constant row does not have the correct type
102    BadConstantRow {
103        /// Expression with the bug
104        source: &'a MirRelationExpr,
105        /// The columns that mismatched, with a debug rendering of the datum we got and the expected type
106        mismatches: Vec<(usize, DatumTypeDifference)>,
107        /// The expected type (which that row does not have)
108        expected: Vec<ReprColumnType>,
109        // TODO(mgree) with a good way to get the type of a Datum, we could give a detailed diff here
110        // this is difficult for empty structures where we may not know the element type
111    },
112    /// Projection of a non-existent column
113    BadProject {
114        /// Expression with the bug
115        source: &'a MirRelationExpr,
116        /// The column projected
117        got: Vec<usize>,
118        /// The input columns (which don't have that column)
119        input_type: Vec<ReprColumnType>,
120    },
121    /// An equivalence class in a join was malformed
122    BadJoinEquivalence {
123        /// Expression with the bug
124        source: &'a MirRelationExpr,
125        /// The join equivalences
126        got: Vec<ReprColumnType>,
127        /// The problem with the join equivalences
128        message: String,
129    },
130    /// TopK grouping by non-existent column
131    BadTopKGroupKey {
132        /// Expression with the bug
133        source: &'a MirRelationExpr,
134        /// The bad column reference in the group key
135        k: usize,
136        /// The input columns (which don't have that column)
137        input_type: Vec<ReprColumnType>,
138    },
139    /// TopK ordering by non-existent column
140    BadTopKOrdering {
141        /// Expression with the bug
142        source: &'a MirRelationExpr,
143        /// The ordering used
144        order: ColumnOrder,
145        /// The input columns (which don't work for that ordering)
146        input_type: Vec<ReprColumnType>,
147    },
148    /// LetRec bindings are malformed
149    BadLetRecBindings {
150        /// Expression with the bug
151        source: &'a MirRelationExpr,
152    },
153    /// Local identifiers are shadowed
154    Shadowing {
155        /// Expression with the bug
156        source: &'a MirRelationExpr,
157        /// The id that was shadowed
158        id: Id,
159    },
160    /// Recursion depth exceeded
161    Recursion {
162        /// The error that aborted recursion
163        error: RecursionLimitError,
164    },
165    /// A dummy value was found
166    DisallowedDummy {
167        /// The expression with the dummy value
168        source: &'a MirRelationExpr,
169    },
170}
171
172impl<'a> From<RecursionLimitError> for TypeError<'a> {
173    fn from(error: RecursionLimitError) -> Self {
174        TypeError::Recursion { error }
175    }
176}
177
/// The known column types of each identifier, as shared between typechecking passes.
type Context = BTreeMap<Id, Vec<ReprColumnType>>;
179
/// Characterizes differences between relation types.
///
/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`.
#[derive(Clone, Debug, Hash)]
pub enum ReprRelationTypeDifference {
    /// `sub` and `sup` don't have the same number of columns
    Length {
        /// Length of `sub`
        len_sub: usize,
        /// Length of `sup`
        len_sup: usize,
    },
    /// `sub` and `sup` differ at the indicated column
    Column {
        /// The (zero-based) column at which `sub` and `sup` differ
        col: usize,
        /// The difference between `sub` and `sup`
        diff: ReprColumnTypeDifference,
    },
}
200
201/// Characterizes differences between individual column types
202///
203/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`
204/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types
205#[derive(Clone, Debug, Hash)]
206pub enum ReprColumnTypeDifference {
207    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
208    NotSubtype {
209        /// Would-be subtype
210        sub: ReprScalarType,
211        /// Would-be supertype
212        sup: ReprScalarType,
213    },
214    /// `sub` was nullable but `sup` was not
215    Nullability {
216        /// Would-be subtype
217        sub: ReprColumnType,
218        /// Would-be supertype
219        sup: ReprColumnType,
220    },
221    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`s
222    ElementType {
223        /// The type constructor (list, array, etc.)
224        ctor: String,
225        /// The difference in the element type
226        element_type: Box<ReprColumnTypeDifference>,
227    },
228    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
229    RecordMissingFields {
230        /// The missing fields
231        missing: Vec<ColumnName>,
232    },
233    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
234    RecordFields {
235        /// The differences, by field
236        fields: Vec<ReprColumnTypeDifference>,
237    },
238}
239
240impl ReprRelationTypeDifference {
241    /// Returns the same type difference, but ignoring nullability
242    ///
243    /// Returns `None` when _all_ of the differences are due to nullability
244    pub fn ignore_nullability(self) -> Option<Self> {
245        use ReprRelationTypeDifference::*;
246
247        match self {
248            Length { .. } => Some(self),
249            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
250        }
251    }
252}
253
254impl ReprColumnTypeDifference {
255    /// Returns the same type difference, but ignoring nullability
256    ///
257    /// Returns `None` when _all_ of the differences are due to nullability
258    pub fn ignore_nullability(self) -> Option<Self> {
259        use ReprColumnTypeDifference::*;
260
261        match self {
262            Nullability { .. } => None,
263            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
264            ElementType { ctor, element_type } => {
265                element_type
266                    .ignore_nullability()
267                    .map(|element_type| ElementType {
268                        ctor,
269                        element_type: Box::new(element_type),
270                    })
271            }
272            RecordFields { fields } => {
273                let fields = fields
274                    .into_iter()
275                    .flat_map(|diff| diff.ignore_nullability())
276                    .collect::<Vec<_>>();
277
278                if fields.is_empty() {
279                    None
280                } else {
281                    Some(RecordFields { fields })
282                }
283            }
284        }
285    }
286}
287
288/// Returns a list of differences that make `sub` not a subtype of `sup`
289///
290/// This function returns an empty list when `sub` is a subtype of `sup`
291pub fn relation_subtype_difference(
292    sub: &[ReprColumnType],
293    sup: &[ReprColumnType],
294) -> Vec<ReprRelationTypeDifference> {
295    let mut diffs = Vec::new();
296
297    if sub.len() != sup.len() {
298        diffs.push(ReprRelationTypeDifference::Length {
299            len_sub: sub.len(),
300            len_sup: sup.len(),
301        });
302
303        // TODO(mgree) we could do an edit-distance computation to report more errors
304        return diffs;
305    }
306
307    diffs.extend(
308        sub.iter()
309            .zip_eq(sup.iter())
310            .enumerate()
311            .flat_map(|(col, (sub_ty, sup_ty))| {
312                column_subtype_difference(sub_ty, sup_ty)
313                    .into_iter()
314                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
315            }),
316    );
317
318    diffs
319}
320
321/// Returns a list of differences that make `sub` not a subtype of `sup`
322///
323/// This function returns an empty list when `sub` is a subtype of `sup`
324pub fn column_subtype_difference(
325    sub: &ReprColumnType,
326    sup: &ReprColumnType,
327) -> Vec<ReprColumnTypeDifference> {
328    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
329
330    if sub.nullable && !sup.nullable {
331        diffs.push(ReprColumnTypeDifference::Nullability {
332            sub: sub.clone(),
333            sup: sup.clone(),
334        });
335    }
336
337    diffs
338}
339
340/// Returns a list of differences that make `sub` not a subtype of `sup`
341///
342/// This function returns an empty list when `sub` is a subtype of `sup`
343pub fn scalar_subtype_difference(
344    sub: &ReprScalarType,
345    sup: &ReprScalarType,
346) -> Vec<ReprColumnTypeDifference> {
347    use ReprScalarType::*;
348
349    let mut diffs = Vec::new();
350
351    match (sub, sup) {
352        (
353            List {
354                element_type: sub_elt,
355                ..
356            },
357            List {
358                element_type: sup_elt,
359                ..
360            },
361        )
362        | (
363            Map {
364                value_type: sub_elt,
365                ..
366            },
367            Map {
368                value_type: sup_elt,
369                ..
370            },
371        )
372        | (
373            Range {
374                element_type: sub_elt,
375                ..
376            },
377            Range {
378                element_type: sup_elt,
379                ..
380            },
381        )
382        | (Array(sub_elt), Array(sup_elt)) => {
383            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
384            diffs.extend(
385                scalar_subtype_difference(sub_elt, sup_elt)
386                    .into_iter()
387                    .map(|diff| ReprColumnTypeDifference::ElementType {
388                        ctor: ctor.clone(),
389                        element_type: Box::new(diff),
390                    }),
391            );
392        }
393        (
394            Record {
395                fields: sub_fields, ..
396            },
397            Record {
398                fields: sup_fields, ..
399            },
400        ) => {
401            if sub_fields.len() != sup_fields.len() {
402                diffs.push(ReprColumnTypeDifference::NotSubtype {
403                    sub: sub.clone(),
404                    sup: sup.clone(),
405                });
406                return diffs;
407            }
408
409            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
410                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
411            }
412        }
413        (_, _) => {
414            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
415                diffs.push(ReprColumnTypeDifference::NotSubtype {
416                    sub: sub.clone(),
417                    sup: sup.clone(),
418                })
419            }
420        }
421    };
422
423    diffs
424}
425
426/// Unions `other` into `typ`, returning a list of differences on failure
427///
428/// This function returns an empty list when `typ` and `other` are a union
429pub fn scalar_union(
430    typ: &mut ReprScalarType,
431    other: &ReprScalarType,
432) -> Vec<ReprColumnTypeDifference> {
433    use ReprScalarType::*;
434
435    let mut diffs = Vec::new();
436
437    // precomputing to appease the borrow checker
438    let ctor = ReprScalarBaseType::from(&*typ);
439    match (typ, other) {
440        (
441            List {
442                element_type: typ_elt,
443            },
444            List {
445                element_type: other_elt,
446            },
447        )
448        | (
449            Map {
450                value_type: typ_elt,
451            },
452            Map {
453                value_type: other_elt,
454            },
455        )
456        | (
457            Range {
458                element_type: typ_elt,
459            },
460            Range {
461                element_type: other_elt,
462            },
463        )
464        | (Array(typ_elt), Array(other_elt)) => {
465            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
466            diffs.extend(
467                res.into_iter()
468                    .map(|diff| ReprColumnTypeDifference::ElementType {
469                        ctor: format!("{ctor:?}"),
470                        element_type: Box::new(diff),
471                    }),
472            );
473        }
474        (
475            Record { fields: typ_fields },
476            Record {
477                fields: other_fields,
478            },
479        ) => {
480            if typ_fields.len() != other_fields.len() {
481                diffs.push(ReprColumnTypeDifference::NotSubtype {
482                    sub: ReprScalarType::Record {
483                        fields: typ_fields.clone(),
484                    },
485                    sup: other.clone(),
486                });
487                return diffs;
488            }
489
490            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
491                diffs.extend(column_union(typ_ty, other_ty));
492            }
493        }
494        (typ, _) => {
495            if ctor != ReprScalarBaseType::from(other) {
496                diffs.push(ReprColumnTypeDifference::NotSubtype {
497                    sub: typ.clone(),
498                    sup: other.clone(),
499                })
500            }
501        }
502    };
503
504    diffs
505}
506
507/// Unions `other` into `typ`, returning a list of differences on failure
508///
509/// This function returns an empty list when `typ` and `other` are a union
510pub fn column_union(
511    typ: &mut ReprColumnType,
512    other: &ReprColumnType,
513) -> Vec<ReprColumnTypeDifference> {
514    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
515
516    if diffs.is_empty() {
517        typ.nullable |= other.nullable;
518    }
519
520    diffs
521}
522
523/// Returns true when it is safe to treat a `sub` row as an `sup` row
524///
525/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
526/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
527pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
528    if sub.len() != sup.len() {
529        return false;
530    }
531
532    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
533        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
534    })
535}
536
537/// Characterizes how a Datum differs from a ReprColumnType
538#[derive(Clone, Debug)]
539pub enum DatumTypeDifference {
540    /// Datum was null, but expected a non-null value (i.e., a ReprScalarType)
541    Null {
542        /// The representation scalar type that we expected
543        expected: ReprScalarType,
544    },
545    /// Datum's kind doesn't match the expected ReprScalarType
546    Mismatch {
547        /// The debug rendering of the Datum that we got (we don't store the Datum itself to avoid borrowing issues)
548        /// This debug rendering includes the DatumKind.
549        got_debug: String,
550        /// The representation scalar type that we expected
551        expected: ReprScalarType,
552    },
553    /// Datum's dimensions didn't match the type's dimensions
554    MismatchDimensions {
555        /// The type constructor (list, array, int2vector, etc.)
556        ctor: String,
557        /// The dimension of the datum that we got
558        got: usize,
559        /// The dimension of the type that we expected
560        expected: usize,
561    },
562    /// There was a nested difference in the element type of some container type
563    ElementType {
564        /// The type constructor (list, array, int2vector, etc.)
565        ctor: String,
566        /// The difference in the element type
567        element_type: Box<DatumTypeDifference>,
568    },
569}
570
571/// Computes the difference between this datum and the specified (representation) column type.
572///
573/// See [`Datum<'a>::is_instance_of`] for just getting a yes/no answer.
574///
575/// See [`Datum<'a>::is_instance_of_sql`] for comparing `Datum`s to `SqlColumnType`s.
576fn datum_difference_with_column_type(
577    datum: &Datum<'_>,
578    column_type: &ReprColumnType,
579) -> Result<(), DatumTypeDifference> {
580    fn difference_with_scalar_type(
581        datum: &Datum<'_>,
582        scalar_type: &ReprScalarType,
583    ) -> Result<(), DatumTypeDifference> {
584        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
585            Err(DatumTypeDifference::Mismatch {
586                // this will be redacted as appropriate
587                got_debug: format!("{got:?}"),
588                expected: expected.clone(),
589            })
590        }
591
592        if let ReprScalarType::Jsonb = scalar_type {
593            // json type checking
594            match datum {
595                Datum::Dummy => Ok(()), // allow dummys (unlike is_instance_of)
596                Datum::Null => Err(DatumTypeDifference::Null {
597                    expected: ReprScalarType::Jsonb,
598                }),
599                Datum::JsonNull
600                | Datum::False
601                | Datum::True
602                | Datum::Numeric(_)
603                | Datum::String(_) => Ok(()),
604                Datum::List(list) => {
605                    for elem in list.iter() {
606                        difference_with_scalar_type(&elem, scalar_type)?;
607                    }
608                    Ok(())
609                }
610                Datum::Map(dict) => {
611                    for (_, val) in dict.iter() {
612                        difference_with_scalar_type(&val, scalar_type)?;
613                    }
614                    Ok(())
615                }
616                _ => mismatch(datum, scalar_type),
617            }
618        } else {
619            fn element_type_difference(
620                ctor: &str,
621                element_type: DatumTypeDifference,
622            ) -> DatumTypeDifference {
623                DatumTypeDifference::ElementType {
624                    ctor: ctor.to_string(),
625                    element_type: Box::new(element_type),
626                }
627            }
628            match (datum, scalar_type) {
629                (Datum::Dummy, _) => Ok(()), // allow dummys (unlike is_instance_of)
630                (Datum::Null, _) => Err(DatumTypeDifference::Null {
631                    expected: scalar_type.clone(),
632                }),
633                (Datum::False, ReprScalarType::Bool) => Ok(()),
634                (Datum::False, _) => mismatch(datum, scalar_type),
635                (Datum::True, ReprScalarType::Bool) => Ok(()),
636                (Datum::True, _) => mismatch(datum, scalar_type),
637                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
638                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
639                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
640                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
641                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
642                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
643                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
644                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
645                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
646                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
647                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
648                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
649                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
650                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
651                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
652                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
653                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
654                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
655                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
656                (Datum::Date(_), _) => mismatch(datum, scalar_type),
657                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
658                (Datum::Time(_), _) => mismatch(datum, scalar_type),
659                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
660                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
661                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
662                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
663                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
664                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
665                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
666                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
667                (Datum::String(_), ReprScalarType::String) => Ok(()),
668                (Datum::String(_), _) => mismatch(datum, scalar_type),
669                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
670                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
671                (Datum::Array(array), ReprScalarType::Array(t)) => {
672                    for e in array.elements().iter() {
673                        if let Datum::Null = e {
674                            continue;
675                        }
676
677                        difference_with_scalar_type(&e, t)
678                            .map_err(|e| element_type_difference("array", e))?;
679                    }
680                    Ok(())
681                }
682                (Datum::Array(array), ReprScalarType::Int2Vector) => {
683                    if array.dims().len() != 1 {
684                        return Err(DatumTypeDifference::MismatchDimensions {
685                            ctor: "int2vector".to_string(),
686                            got: array.dims().len(),
687                            expected: 1,
688                        });
689                    }
690
691                    for e in array.elements().iter() {
692                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
693                            .map_err(|e| element_type_difference("int2vector", e))?;
694                    }
695
696                    Ok(())
697                }
698                (Datum::Array(_), _) => mismatch(datum, scalar_type),
699                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
700                    for e in list.iter() {
701                        if let Datum::Null = e {
702                            continue;
703                        }
704
705                        difference_with_scalar_type(&e, element_type)
706                            .map_err(|e| element_type_difference("list", e))?;
707                    }
708                    Ok(())
709                }
710                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
711                    let len = list.iter().count();
712                    if len != fields.len() {
713                        return Err(DatumTypeDifference::MismatchDimensions {
714                            ctor: "record".to_string(),
715                            got: len,
716                            expected: fields.len(),
717                        });
718                    }
719
720                    for (e, t) in list.iter().zip_eq(fields) {
721                        if let Datum::Null = e {
722                            if t.nullable {
723                                continue;
724                            } else {
725                                return Err(DatumTypeDifference::Null {
726                                    expected: t.scalar_type.clone(),
727                                });
728                            }
729                        }
730
731                        difference_with_scalar_type(&e, &t.scalar_type)
732                            .map_err(|e| element_type_difference("record", e))?;
733                    }
734                    Ok(())
735                }
736                (Datum::List(_), _) => mismatch(datum, scalar_type),
737                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
738                    for (_, v) in map.iter() {
739                        if let Datum::Null = v {
740                            continue;
741                        }
742
743                        difference_with_scalar_type(&v, value_type)
744                            .map_err(|e| element_type_difference("map", e))?;
745                    }
746                    Ok(())
747                }
748                (Datum::Map(_), _) => mismatch(datum, scalar_type),
749                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
750                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
751                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
752                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
753                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
754                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
755                    match inner {
756                        None => Ok(()),
757                        Some(inner) => {
758                            if let Some(b) = inner.lower.bound {
759                                difference_with_scalar_type(&b.datum(), element_type)
760                                    .map_err(|e| element_type_difference("range", e))?;
761                            }
762                            if let Some(b) = inner.upper.bound {
763                                difference_with_scalar_type(&b.datum(), element_type)
764                                    .map_err(|e| element_type_difference("range", e))?;
765                            }
766                            Ok(())
767                        }
768                    }
769                }
770                (Datum::Range(_), _) => mismatch(datum, scalar_type),
771                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
772                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
773                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
774                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
775            }
776        }
777    }
778    if column_type.nullable {
779        if let Datum::Null = datum {
780            return Ok(());
781        }
782    }
783    difference_with_scalar_type(datum, &column_type.scalar_type)
784}
785
786fn row_difference_with_column_types<'a>(
787    source: &'a MirRelationExpr,
788    datums: &Vec<Datum<'_>>,
789    column_types: &[ReprColumnType],
790) -> Result<(), TypeError<'a>> {
791    // correct length
792    if datums.len() != column_types.len() {
793        return Err(TypeError::BadConstantRowLen {
794            source,
795            got: datums.len(),
796            expected: column_types.to_vec(),
797        });
798    }
799
800    // correct types
801    let mut mismatches = Vec::new();
802    for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
803        if let Err(e) = datum_difference_with_column_type(d, ty) {
804            mismatches.push((i, e));
805        }
806    }
807    if !mismatches.is_empty() {
808        return Err(TypeError::BadConstantRow {
809            source,
810            mismatches,
811            expected: column_types.to_vec(),
812        });
813    }
814
815    Ok(())
816}
817/// Check that the visible type of each query has not been changed
818#[derive(Debug)]
819pub struct Typecheck {
820    /// The known types of the queries so far
821    ctx: SharedContext,
822    /// Whether or not this is the first run of the transform
823    disallow_new_globals: bool,
824    /// Whether or not to be strict about join equivalences having the same nullability
825    strict_join_equivalences: bool,
826    /// Whether or not to disallow dummy values
827    disallow_dummy: bool,
828    /// Recursion guard for checked recursion
829    recursion_guard: RecursionGuard,
830}
831
832impl CheckedRecursion for Typecheck {
833    fn recursion_guard(&self) -> &RecursionGuard {
834        &self.recursion_guard
835    }
836}
837
838impl Typecheck {
839    /// Creates a typechecking consistency checking pass using a given shared context
840    pub fn new(ctx: SharedContext) -> Self {
841        Self {
842            ctx,
843            disallow_new_globals: false,
844            strict_join_equivalences: false,
845            disallow_dummy: false,
846            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
847        }
848    }
849
850    /// New non-transient global IDs will be treated as an error
851    ///
852    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
853    pub fn disallow_new_globals(mut self) -> Self {
854        self.disallow_new_globals = true;
855        self
856    }
857
858    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
859    ///
860    /// Only turn this on before `JoinImplementation`
861    pub fn strict_join_equivalences(mut self) -> Self {
862        self.strict_join_equivalences = true;
863
864        self
865    }
866
867    /// Disallow dummy values
868    pub fn disallow_dummy(mut self) -> Self {
869        self.disallow_dummy = true;
870        self
871    }
872
873    /// Returns the type of a relation expression or a type error.
874    ///
875    /// This function is careful to check validity, not just find out the type.
876    ///
877    /// It should be linear in the size of the AST.
878    ///
879    /// ??? should we also compute keys and return a `ReprRelationType`?
880    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
881    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
882    ///   reliably keep it stable across transformations.
883    pub fn typecheck<'a>(
884        &self,
885        expr: &'a MirRelationExpr,
886        ctx: &Context,
887    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
888        use MirRelationExpr::*;
889
890        self.checked_recur(|tc| match expr {
891            Constant { typ, rows } => {
892                if let Ok(rows) = rows {
893                    for (row, _id) in rows {
894                        let datums = row.unpack();
895
896                        row_difference_with_column_types(expr, &datums, &typ.column_types.iter().map(ReprColumnType::from).collect_vec())?;
897
898                        if self.disallow_dummy && datums.iter().any(|d| d == &mz_repr::Datum::Dummy) {
899                            return Err(TypeError::DisallowedDummy {
900                                source: expr,
901                            });
902                        }
903                    }
904                }
905
906                Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec())
907            }
908            Get { typ, id, .. } => {
909                if let Id::Global(_global_id) = id {
910                    if !ctx.contains_key(id) {
911                        // TODO(mgree) pass QueryContext through to check these types
912                        return Ok(typ.column_types.iter().map(ReprColumnType::from).collect_vec());
913                    }
914                }
915
916                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
917                    source: expr,
918                    id: id.clone(),
919                    typ: ReprRelationType::from(typ),
920                })?;
921
922                let column_types = typ.column_types.iter().map(ReprColumnType::from).collect_vec();
923
924                // covariant: the ascribed type must be a subtype of the actual type in the context
925                let diffs = relation_subtype_difference(&column_types, ctx_typ).into_iter().flat_map(|diff| diff.ignore_nullability()).collect::<Vec<_>>();
926
927                if !diffs.is_empty() {
928                    return Err(TypeError::MismatchColumns {
929                        source: expr,
930                        got: column_types,
931                        expected: ctx_typ.clone(),
932                        diffs,
933                        message: "annotation did not match context type".to_string(),
934                    });
935                }
936
937                Ok(column_types)
938            }
939            Project { input, outputs } => {
940                let t_in = tc.typecheck(input, ctx)?;
941
942                for x in outputs {
943                    if *x >= t_in.len() {
944                        return Err(TypeError::BadProject {
945                            source: expr,
946                            got: outputs.clone(),
947                            input_type: t_in,
948                        });
949                    }
950                }
951
952                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
953            }
954            Map { input, scalars } => {
955                let mut t_in = tc.typecheck(input, ctx)?;
956
957                for scalar_expr in scalars.iter() {
958                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
959
960                    if self.disallow_dummy && scalar_expr.contains_dummy() {
961                        return Err(TypeError::DisallowedDummy {
962                            source: expr,
963                        });
964                    }
965                }
966
967                Ok(t_in)
968            }
969            FlatMap { input, func, exprs } => {
970                let mut t_in = tc.typecheck(input, ctx)?;
971
972                let mut t_exprs = Vec::with_capacity(exprs.len());
973                for scalar_expr in exprs {
974                    t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
975
976                    if self.disallow_dummy && scalar_expr.contains_dummy() {
977                        return Err(TypeError::DisallowedDummy {
978                            source: expr,
979                        });
980                    }
981                }
982                // TODO(mgree) check t_exprs agrees with `func`'s input type
983
984                let t_out = func.output_type().column_types.iter().map(ReprColumnType::from).collect_vec();
985
986                // FlatMap extends the existing columns
987                t_in.extend(t_out);
988                Ok(t_in)
989            }
990            Filter { input, predicates } => {
991                let mut t_in = tc.typecheck(input, ctx)?;
992
993                // Set as nonnull any columns where null values would cause
994                // any predicate to evaluate to null.
995                for column in non_nullable_columns(predicates) {
996                    t_in[column].nullable = false;
997                }
998
999                for scalar_expr in predicates {
1000                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1001
1002                    // filter condition must be boolean
1003                    // ignoring nullability: null is treated as false
1004                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1005                    if t.scalar_type != ReprScalarType::Bool {
1006                        let sub = t.scalar_type.clone();
1007
1008                        return Err(TypeError::MismatchColumn {
1009                            source: expr,
1010                            got: t,
1011                            expected: ReprColumnType {
1012                                scalar_type: ReprScalarType::Bool,
1013                                nullable: true,
1014                            },
1015                            diffs: vec![ReprColumnTypeDifference::NotSubtype { sub, sup: ReprScalarType::Bool }],
1016                            message: "expected boolean condition".to_string(),
1017                        });
1018                    }
1019
1020                    if self.disallow_dummy && scalar_expr.contains_dummy() {
1021                        return Err(TypeError::DisallowedDummy {
1022                            source: expr,
1023                        });
1024                    }
1025                }
1026
1027                Ok(t_in)
1028            }
1029            Join {
1030                inputs,
1031                equivalences,
1032                implementation,
1033            } => {
1034                let mut t_in_global = Vec::new();
1035                let mut t_in_local = vec![Vec::new(); inputs.len()];
1036
1037                for (i, input) in inputs.iter().enumerate() {
1038                    let input_t = tc.typecheck(input, ctx)?;
1039                    t_in_global.extend(input_t.clone());
1040                    t_in_local[i] = input_t;
1041                }
1042
1043                for eq_class in equivalences {
1044                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1045
1046                    let mut all_nullable = true;
1047
1048                    for scalar_expr in eq_class {
1049                        // Note: the equivalences have global column references
1050                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1051
1052                        if !t_expr.nullable {
1053                            all_nullable = false;
1054                        }
1055
1056                        if let Some(t_first) = t_exprs.get(0) {
1057                            let diffs = scalar_subtype_difference(&t_expr.scalar_type, &t_first.scalar_type);
1058                            if !diffs.is_empty() {
1059                                return Err(TypeError::MismatchColumn {
1060                                    source: expr,
1061                                    got: t_expr,
1062                                    expected: t_first.clone(),
1063                                    diffs,
1064                                    message: "equivalence class members have different scalar types".to_string(),
1065                                });
1066                            }
1067
1068                            // equivalences may or may not match on nullability
1069                            // before JoinImplementation runs, nullability should match.
1070                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
1071                            if self.strict_join_equivalences {
1072                                if t_expr.nullable != t_first.nullable {
1073                                    let sub = t_expr.clone();
1074                                    let sup = t_first.clone();
1075
1076                                    let err = TypeError::MismatchColumn {
1077                                        source: expr,
1078                                        got: t_expr.clone(),
1079                                        expected: t_first.clone(),
1080                                        diffs: vec![ReprColumnTypeDifference::Nullability { sub, sup }],
1081                                        message: "equivalence class members have different nullability (and join equivalence checking is strict)".to_string(),
1082                                    };
1083
1084                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1085                                    ::tracing::debug!("{err}");
1086                                }
1087                            }
1088                        }
1089
1090                        if self.disallow_dummy && scalar_expr.contains_dummy() {
1091                            return Err(TypeError::DisallowedDummy {
1092                                source: expr,
1093                            });
1094                        }
1095
1096                        t_exprs.push(t_expr);
1097                    }
1098
1099                    if self.strict_join_equivalences && all_nullable {
1100                        let err = TypeError::BadJoinEquivalence {
1101                            source: expr,
1102                            got: t_exprs,
1103                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1104                        };
1105
1106                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1107                        ::tracing::debug!("{err}");
1108                    }
1109                }
1110
1111                // check that the join implementation is consistent
1112                match implementation {
1113                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
1114                        if let Some(key) = first_key {
1115                            for k in key {
1116                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1117                            }
1118                        }
1119
1120                        for (idx, key, _) in others {
1121                            for k in key {
1122                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1123                            }
1124                        }
1125                    }
1126                    JoinImplementation::DeltaQuery(plans) => {
1127                        for plan in plans {
1128                            for (idx, key, _) in plan {
1129                                for k in key {
1130                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1131                                }
1132                            }
1133                        }
1134                    }
1135                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1136                        let typ: Vec<ReprColumnType> = key
1137                            .iter()
1138                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1139                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1140
1141                        for row in consts {
1142                            let datums = row.unpack();
1143
1144                            row_difference_with_column_types(expr, &datums, &typ)?;
1145                        }
1146                    }
1147                    JoinImplementation::Unimplemented => (),
1148                }
1149
1150                Ok(t_in_global)
1151            }
1152            Reduce {
1153                input,
1154                group_key,
1155                aggregates,
1156                monotonic: _,
1157                expected_group_size: _,
1158            } => {
1159                let t_in = tc.typecheck(input, ctx)?;
1160
1161                let mut t_out = group_key
1162                    .iter()
1163                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1164                    .collect::<Result<Vec<_>, _>>()?;
1165
1166                    if self.disallow_dummy && group_key.iter().any(|scalar_expr| scalar_expr.contains_dummy()) {
1167                        return Err(TypeError::DisallowedDummy {
1168                            source: expr,
1169                        });
1170                    }
1171
1172                for agg in aggregates {
1173                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1174                }
1175
1176                Ok(t_out)
1177            }
1178            TopK {
1179                input,
1180                group_key,
1181                order_key,
1182                limit: _,
1183                offset: _,
1184                monotonic: _,
1185                expected_group_size: _,
1186            } => {
1187                let t_in = tc.typecheck(input, ctx)?;
1188
1189                for &k in group_key {
1190                    if k >= t_in.len() {
1191                        return Err(TypeError::BadTopKGroupKey {
1192                            source: expr,
1193                            k,
1194                            input_type: t_in,
1195                        });
1196                    }
1197                }
1198
1199                for order in order_key {
1200                    if order.column >= t_in.len() {
1201                        return Err(TypeError::BadTopKOrdering {
1202                            source: expr,
1203                            order: order.clone(),
1204                            input_type: t_in,
1205                        });
1206                    }
1207                }
1208
1209                Ok(t_in)
1210            }
1211            Negate { input } => tc.typecheck(input, ctx),
1212            Threshold { input } => tc.typecheck(input, ctx),
1213            Union { base, inputs } => {
1214                let mut t_base = tc.typecheck(base, ctx)?;
1215
1216                for input in inputs {
1217                    let t_input = tc.typecheck(input, ctx)?;
1218
1219                    let len_sub = t_base.len();
1220                    let len_sup = t_input.len();
1221                    if len_sub != len_sup {
1222                        return Err(TypeError::MismatchColumns {
1223                            source: expr,
1224                            got: t_base.clone(),
1225                            expected: t_input,
1226                            diffs: vec![ReprRelationTypeDifference::Length {
1227                                len_sub,
1228                                len_sup,
1229                            }],
1230                            message: "Union branches have different numbers of columns".to_string(),
1231                        });
1232                    }
1233
1234                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1235                        let diffs = column_union(base_col, &input_col);
1236                        if !diffs.is_empty() {
1237                            return Err(TypeError::MismatchColumn {
1238                                    source: expr,
1239                                    got: input_col,
1240                                    expected: base_col.clone(),
1241                                    diffs,
1242                                    message:
1243                                        "couldn't compute union of column types in Union"
1244                                    .to_string(),
1245                            });
1246                        }
1247
1248                    }
1249                }
1250
1251                Ok(t_base)
1252            }
1253            Let { id, value, body } => {
1254                let t_value = tc.typecheck(value, ctx)?;
1255
1256                let binding = Id::Local(*id);
1257                if ctx.contains_key(&binding) {
1258                    return Err(TypeError::Shadowing {
1259                        source: expr,
1260                        id: binding,
1261                    });
1262                }
1263
1264                let mut body_ctx = ctx.clone();
1265                body_ctx.insert(Id::Local(*id), t_value);
1266
1267                tc.typecheck(body, &body_ctx)
1268            }
1269            LetRec { ids, values, body, limits: _ } => {
1270                if ids.len() != values.len() {
1271                    return Err(TypeError::BadLetRecBindings { source: expr });
1272                }
1273
1274                // temporary hack: steal info from the Gets inside to learn the expected types
1275                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1276                let mut ctx = ctx.clone();
1277                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1278                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1279                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1280                }
1281
1282                for (id, value) in ids.iter().zip_eq(values.iter()) {
1283                    let typ = tc.typecheck(value, &ctx)?;
1284
1285                    let id = Id::Local(id.clone());
1286                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1287                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1288                            // we expect an EXACT match, but don't care about nullability
1289                            let diffs = column_union(base_col, &input_col);
1290                            if !diffs.is_empty() {
1291                                 return Err(TypeError::MismatchColumn {
1292                                        source: expr,
1293                                        got: input_col,
1294                                        expected: base_col.clone(),
1295                                        diffs,
1296                                        message:
1297                                            "couldn't compute union of column types in LetRec"
1298                                        .to_string(),
1299                                    })
1300                            }
1301                        }
1302                    } else {
1303                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1304                        ctx.insert(id, typ);
1305                    }
1306                }
1307
1308                tc.typecheck(body, &ctx)
1309            }
1310            ArrangeBy { input, keys } => {
1311                let t_in = tc.typecheck(input, ctx)?;
1312
1313                for key in keys {
1314                    for k in key {
1315                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1316                    }
1317                }
1318
1319                Ok(t_in)
1320            }
1321        })
1322    }
1323
1324    /// Traverses a term to collect the types of given ids.
1325    ///
1326    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1327    fn collect_recursive_variable_types<'a>(
1328        &self,
1329        expr: &'a MirRelationExpr,
1330        ids: &[LocalId],
1331        ctx: &mut Context,
1332    ) -> Result<(), TypeError<'a>> {
1333        use MirRelationExpr::*;
1334
1335        self.checked_recur(|tc| {
1336            match expr {
1337                Get {
1338                    id: Id::Local(id),
1339                    typ,
1340                    ..
1341                } => {
1342                    if !ids.contains(id) {
1343                        return Ok(());
1344                    }
1345
1346                    let id = Id::Local(id.clone());
1347                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1348                        let typ = typ
1349                            .column_types
1350                            .iter()
1351                            .map(ReprColumnType::from)
1352                            .collect_vec();
1353
1354                        if ctx_typ.len() != typ.len() {
1355                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1356
1357                            return Err(TypeError::MismatchColumns {
1358                                source: expr,
1359                                got: typ,
1360                                expected: ctx_typ.clone(),
1361                                diffs,
1362                                message: "environment and type annotation did not match"
1363                                    .to_string(),
1364                            });
1365                        }
1366
1367                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1368                            let diffs = column_union(base_col, &input_col);
1369                            if !diffs.is_empty() {
1370                                return Err(TypeError::MismatchColumn {
1371                                    source: expr,
1372                                    got: input_col,
1373                                    expected: base_col.clone(),
1374                                    diffs,
1375                                    message:
1376                                        "couldn't compute union of column types in Get and context"
1377                                            .to_string(),
1378                                });
1379                            }
1380                        }
1381                    } else {
1382                        ctx.insert(
1383                            id,
1384                            typ.column_types
1385                                .iter()
1386                                .map(ReprColumnType::from)
1387                                .collect_vec(),
1388                        );
1389                    }
1390                }
1391                Get {
1392                    id: Id::Global(..), ..
1393                }
1394                | Constant { .. } => (),
1395                Let { id, value, body } => {
1396                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1397
1398                    // we've shadowed the id
1399                    if ids.contains(id) {
1400                        return Err(TypeError::Shadowing {
1401                            source: expr,
1402                            id: Id::Local(*id),
1403                        });
1404                    }
1405
1406                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1407                }
1408                LetRec {
1409                    ids: inner_ids,
1410                    values,
1411                    body,
1412                    limits: _,
1413                } => {
1414                    for inner_id in inner_ids {
1415                        if ids.contains(inner_id) {
1416                            return Err(TypeError::Shadowing {
1417                                source: expr,
1418                                id: Id::Local(*inner_id),
1419                            });
1420                        }
1421                    }
1422
1423                    for value in values {
1424                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1425                    }
1426
1427                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1428                }
1429                Project { input, .. }
1430                | Map { input, .. }
1431                | FlatMap { input, .. }
1432                | Filter { input, .. }
1433                | Reduce { input, .. }
1434                | TopK { input, .. }
1435                | Negate { input }
1436                | Threshold { input }
1437                | ArrangeBy { input, .. } => {
1438                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1439                }
1440                Join { inputs, .. } => {
1441                    for input in inputs {
1442                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1443                    }
1444                }
1445                Union { base, inputs } => {
1446                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1447
1448                    for input in inputs {
1449                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1450                    }
1451                }
1452            }
1453
1454            Ok(())
1455        })
1456    }
1457
1458    fn typecheck_scalar<'a>(
1459        &self,
1460        expr: &'a MirScalarExpr,
1461        source: &'a MirRelationExpr,
1462        column_types: &[ReprColumnType],
1463    ) -> Result<ReprColumnType, TypeError<'a>> {
1464        use MirScalarExpr::*;
1465
1466        self.checked_recur(|tc| match expr {
1467            Column(i, _) => match column_types.get(*i) {
1468                Some(ty) => Ok(ty.clone()),
1469                None => Err(TypeError::NoSuchColumn {
1470                    source,
1471                    expr,
1472                    col: *i,
1473                }),
1474            },
1475            Literal(row, typ) => {
1476                let typ = ReprColumnType::from(typ);
1477                if let Ok(row) = row {
1478                    let datums = row.unpack();
1479
1480                    row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1481                }
1482
1483                Ok(typ)
1484            }
1485            CallUnmaterializable(func) => Ok(ReprColumnType::from(&func.output_type())),
1486            CallUnary { expr, func } => {
1487                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1488                let typ_out = func.output_type(SqlColumnType::from_repr(&typ_in));
1489                Ok(ReprColumnType::from(&typ_out))
1490            }
1491            CallBinary { expr1, expr2, func } => {
1492                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1493                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1494                let typ_out = func.output_type(
1495                    SqlColumnType::from_repr(&typ_in1),
1496                    SqlColumnType::from_repr(&typ_in2),
1497                );
1498                Ok(ReprColumnType::from(&typ_out))
1499            }
1500            CallVariadic { exprs, func } => Ok(ReprColumnType::from(
1501                &func.output_type(
1502                    exprs
1503                        .iter()
1504                        .map(|e| {
1505                            tc.typecheck_scalar(e, source, column_types)
1506                                .map(|typ| SqlColumnType::from_repr(&typ))
1507                        })
1508                        .collect::<Result<Vec<_>, TypeError>>()?,
1509                ),
1510            )),
1511            If { cond, then, els } => {
1512                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1513
1514                // condition must be boolean
1515                // ignoring nullability: null is treated as false
1516                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1517                if cond_type.scalar_type != ReprScalarType::Bool {
1518                    let sub = cond_type.scalar_type.clone();
1519
1520                    return Err(TypeError::MismatchColumn {
1521                        source,
1522                        got: cond_type,
1523                        expected: ReprColumnType {
1524                            scalar_type: ReprScalarType::Bool,
1525                            nullable: true,
1526                        },
1527                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1528                            sub,
1529                            sup: ReprScalarType::Bool,
1530                        }],
1531                        message: "expected boolean condition".to_string(),
1532                    });
1533                }
1534
1535                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1536                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1537
1538                let diffs = column_union(&mut then_type, &else_type);
1539                if !diffs.is_empty() {
1540                    return Err(TypeError::MismatchColumn {
1541                        source,
1542                        got: then_type,
1543                        expected: else_type,
1544                        diffs,
1545                        message: "couldn't compute union of column types for If".to_string(),
1546                    });
1547                }
1548
1549                Ok(then_type)
1550            }
1551        })
1552    }
1553
1554    /// Typecheck an `AggregateExpr`
1555    pub fn typecheck_aggregate<'a>(
1556        &self,
1557        expr: &'a AggregateExpr,
1558        source: &'a MirRelationExpr,
1559        column_types: &[ReprColumnType],
1560    ) -> Result<ReprColumnType, TypeError<'a>> {
1561        self.checked_recur(|tc| {
1562            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1563
1564            // TODO check that t_in is actually acceptable for `func`
1565
1566            Ok(ReprColumnType::from(
1567                &expr.func.output_type(SqlColumnType::from_repr(&t_in)),
1568            ))
1569        })
1570    }
1571}
1572
/// Reports a type error (or warning), escalating according to a severity flag.
///
/// Usage: `type_error!(severity, <format args>...)`, where everything after the
/// severity flag is passed through as `format!`-style arguments.
///
/// - If `severity` is `true`, the message is reported via `soft_panic_or_log!`: a failure in
///   CI and a logged error in production (visible in Sentry).
/// - Otherwise the message is only emitted at `debug` level via `tracing`. This is NOT
///   logged as a warning or error.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        if $severity {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1585
1586impl crate::Transform for Typecheck {
1587    fn name(&self) -> &'static str {
1588        "Typecheck"
1589    }
1590
1591    fn actually_perform_transform(
1592        &self,
1593        relation: &mut MirRelationExpr,
1594        transform_ctx: &mut crate::TransformCtx,
1595    ) -> Result<(), crate::TransformError> {
1596        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1597
1598        let expected = transform_ctx
1599            .global_id
1600            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1601
1602        if let Some(id) = transform_ctx.global_id {
1603            if self.disallow_new_globals
1604                && expected.is_none()
1605                && transform_ctx.global_id.is_some()
1606                && !id.is_transient()
1607            {
1608                type_error!(
1609                    false, // not severe
1610                    "type warning: new non-transient global id {id}\n{}",
1611                    relation.pretty()
1612                );
1613            }
1614        }
1615
1616        let got = self.typecheck(relation, &typecheck_ctx);
1617
1618        let humanizer = mz_repr::explain::DummyHumanizer;
1619
1620        match (got, expected) {
1621            (Ok(got), Some(expected)) => {
1622                let id = transform_ctx.global_id.unwrap();
1623
1624                // contravariant: global types can be updated
1625                let diffs = relation_subtype_difference(expected, &got);
1626                if !diffs.is_empty() {
1627                    // SEVERE only if got and expected have true differences, not just nullability
1628                    let severity = diffs
1629                        .iter()
1630                        .any(|diff| diff.clone().ignore_nullability().is_some());
1631
1632                    let err = TypeError::MismatchColumns {
1633                        source: relation,
1634                        got,
1635                        expected: expected.clone(),
1636                        diffs,
1637                        message: format!(
1638                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1639                        ),
1640                    };
1641
1642                    type_error!(severity, "type error in known global id {id}:\n{err}");
1643                }
1644            }
1645            (Ok(got), None) => {
1646                if let Some(id) = transform_ctx.global_id {
1647                    typecheck_ctx.insert(Id::Global(id), got);
1648                }
1649            }
1650            (Err(err), _) => {
1651                let (expected, binding) = match expected {
1652                    Some(expected) => {
1653                        let id = transform_ctx.global_id.unwrap();
1654                        (
1655                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1656                            format!("known global id {id}"),
1657                        )
1658                    }
1659                    None => ("".to_string(), "transient query".to_string()),
1660                };
1661
1662                type_error!(
1663                    true, // SEVERE: the transformed code is inconsistent
1664                    "type error in {binding}:\n{err}\n{expected}{}",
1665                    relation.pretty()
1666                );
1667            }
1668        }
1669
1670        Ok(())
1671    }
1672}
1673
1674/// Prints a type prettily with a given `ExprHumanizer`
1675pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1676where
1677    H: ExprHumanizer,
1678{
1679    let mut s = String::with_capacity(2 + 3 * cols.len());
1680
1681    s.push('(');
1682
1683    let mut it = cols.iter().peekable();
1684    while let Some(col) = it.next() {
1685        s.push_str(&humanizer.humanize_column_type_repr(col, false));
1686
1687        if it.peek().is_some() {
1688            s.push_str(", ");
1689        }
1690    }
1691
1692    s.push(')');
1693
1694    s
1695}
1696
1697impl ReprRelationTypeDifference {
1698    /// Pretty prints a type difference
1699    ///
1700    /// Always indents two spaces
1701    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1702    where
1703        H: ExprHumanizer,
1704    {
1705        use ReprRelationTypeDifference::*;
1706        match self {
1707            Length { len_sub, len_sup } => {
1708                writeln!(
1709                    f,
1710                    "  number of columns do not match ({len_sub} != {len_sup})"
1711                )
1712            }
1713            Column { col, diff } => {
1714                writeln!(f, "  column {col} differs:")?;
1715                diff.humanize(4, h, f)
1716            }
1717        }
1718    }
1719}
1720
1721impl ReprColumnTypeDifference {
1722    /// Pretty prints a type difference at a given indentation level
1723    pub fn humanize<H>(
1724        &self,
1725        indent: usize,
1726        h: &H,
1727        f: &mut std::fmt::Formatter<'_>,
1728    ) -> std::fmt::Result
1729    where
1730        H: ExprHumanizer,
1731    {
1732        use ReprColumnTypeDifference::*;
1733
1734        // indent
1735        write!(f, "{:indent$}", "")?;
1736
1737        match self {
1738            NotSubtype { sub, sup } => {
1739                let sub = h.humanize_scalar_type_repr(sub, false);
1740                let sup = h.humanize_scalar_type_repr(sup, false);
1741
1742                writeln!(f, "{sub} is a not a subtype of {sup}")
1743            }
1744            Nullability { sub, sup } => {
1745                let sub = h.humanize_column_type_repr(sub, false);
1746                let sup = h.humanize_column_type_repr(sup, false);
1747
1748                writeln!(f, "{sub} is nullable but {sup} is not")
1749            }
1750            ElementType { ctor, element_type } => {
1751                writeln!(f, "{ctor} element types differ:")?;
1752
1753                element_type.humanize(indent + 2, h, f)
1754            }
1755            RecordMissingFields { missing } => {
1756                write!(f, "missing column fields:")?;
1757                for col in missing {
1758                    write!(f, " {col}")?;
1759                }
1760                f.write_char('\n')
1761            }
1762            RecordFields { fields } => {
1763                writeln!(f, "{} record fields differ:", fields.len())?;
1764
1765                for (i, diff) in fields.iter().enumerate() {
1766                    writeln!(f, "{:indent$}  field {i}:", "")?;
1767                    diff.humanize(indent + 4, h, f)?;
1768                }
1769                Ok(())
1770            }
1771        }
1772    }
1773}
1774
1775impl DatumTypeDifference {
1776    /// Pretty prints a type difference at a given indentation level
1777    pub fn humanize<H>(
1778        &self,
1779        indent: usize,
1780        h: &H,
1781        f: &mut std::fmt::Formatter<'_>,
1782    ) -> std::fmt::Result
1783    where
1784        H: ExprHumanizer,
1785    {
1786        // indent
1787        write!(f, "{:indent$}", "")?;
1788
1789        match self {
1790            DatumTypeDifference::Null { expected } => {
1791                let expected = h.humanize_scalar_type_repr(expected, false);
1792                writeln!(
1793                    f,
1794                    "unexpected null, expected representation type {expected}"
1795                )?
1796            }
1797            DatumTypeDifference::Mismatch {
1798                got_debug,
1799                expected,
1800            } => {
1801                let expected = h.humanize_scalar_type_repr(expected, false);
1802                // NB `got_debug` will be redacted as appropriate
1803                writeln!(
1804                    f,
1805                    "got datum {got_debug}, expected representation type {expected}"
1806                )?;
1807            }
1808            DatumTypeDifference::MismatchDimensions {
1809                ctor,
1810                got,
1811                expected,
1812            } => {
1813                writeln!(
1814                    f,
1815                    "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1816                )?;
1817            }
1818            DatumTypeDifference::ElementType { ctor, element_type } => {
1819                writeln!(f, "{ctor} element types differ:")?;
1820                element_type.humanize(indent + 4, h, f)?;
1821            }
1822        }
1823
1824        Ok(())
1825    }
1826}
1827
1828/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
1829#[allow(missing_debug_implementations)]
1830pub struct TypeErrorHumanizer<'a, 'b, H>
1831where
1832    H: ExprHumanizer,
1833{
1834    err: &'a TypeError<'a>,
1835    humanizer: &'b H,
1836}
1837
1838impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1839where
1840    H: ExprHumanizer,
1841{
1842    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1843    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1844        Self { err, humanizer }
1845    }
1846}
1847
1848impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1849where
1850    H: ExprHumanizer,
1851{
1852    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1853        self.err.humanize(self.humanizer, f)
1854    }
1855}
1856
1857impl<'a> std::fmt::Display for TypeError<'a> {
1858    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1859        TypeErrorHumanizer {
1860            err: self,
1861            humanizer: &DummyHumanizer,
1862        }
1863        .fmt(f)
1864    }
1865}
1866
1867impl<'a> TypeError<'a> {
1868    /// The source of the type error
1869    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1870        use TypeError::*;
1871        match self {
1872            Unbound { source, .. }
1873            | NoSuchColumn { source, .. }
1874            | MismatchColumn { source, .. }
1875            | MismatchColumns { source, .. }
1876            | BadConstantRowLen { source, .. }
1877            | BadConstantRow { source, .. }
1878            | BadProject { source, .. }
1879            | BadJoinEquivalence { source, .. }
1880            | BadTopKGroupKey { source, .. }
1881            | BadTopKOrdering { source, .. }
1882            | BadLetRecBindings { source }
1883            | Shadowing { source, .. }
1884            | DisallowedDummy { source, .. } => Some(source),
1885            Recursion { .. } => None,
1886        }
1887    }
1888
1889    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1890    where
1891        H: ExprHumanizer,
1892    {
1893        if let Some(source) = self.source() {
1894            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1895        }
1896
1897        use TypeError::*;
1898        match self {
1899            Unbound { source: _, id, typ } => {
1900                let typ = columns_pretty(&typ.column_types, humanizer);
1901                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1902            }
1903            NoSuchColumn {
1904                source: _,
1905                expr,
1906                col,
1907            } => writeln!(f, "{expr} references non-existent column {col}")?,
1908            MismatchColumn {
1909                source: _,
1910                got,
1911                expected,
1912                diffs,
1913                message,
1914            } => {
1915                let got = humanizer.humanize_column_type_repr(got, false);
1916                let expected = humanizer.humanize_column_type_repr(expected, false);
1917                writeln!(
1918                    f,
1919                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1920                )?;
1921
1922                for diff in diffs {
1923                    diff.humanize(2, humanizer, f)?;
1924                }
1925            }
1926            MismatchColumns {
1927                source: _,
1928                got,
1929                expected,
1930                diffs,
1931                message,
1932            } => {
1933                let got = columns_pretty(got, humanizer);
1934                let expected = columns_pretty(expected, humanizer);
1935
1936                writeln!(
1937                    f,
1938                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1939                )?;
1940
1941                for diff in diffs {
1942                    diff.humanize(humanizer, f)?;
1943                }
1944            }
1945            BadConstantRowLen {
1946                source: _,
1947                got,
1948                expected,
1949            } => {
1950                let expected = columns_pretty(expected, humanizer);
1951                writeln!(
1952                    f,
1953                    "bad constant row\n      row has length {got}\nexpected row of type {expected}"
1954                )?
1955            }
1956            BadConstantRow {
1957                source: _,
1958                mismatches,
1959                expected,
1960            } => {
1961                let expected = columns_pretty(expected, humanizer);
1962
1963                let num_mismatches = mismatches.len();
1964                let plural = if num_mismatches == 1 { "" } else { "es" };
1965                writeln!(
1966                    f,
1967                    "bad constant row\n      got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1968                )?;
1969
1970                if num_mismatches > 0 {
1971                    writeln!(f, "")?;
1972                    for (col, diff) in mismatches.iter() {
1973                        writeln!(f, "      column #{col}:")?;
1974                        diff.humanize(8, humanizer, f)?;
1975                    }
1976                }
1977            }
1978            BadProject {
1979                source: _,
1980                got,
1981                input_type,
1982            } => {
1983                let input_type = columns_pretty(input_type, humanizer);
1984
1985                writeln!(
1986                    f,
1987                    "projection of non-existant columns {got:?} from type {input_type}"
1988                )?
1989            }
1990            BadJoinEquivalence {
1991                source: _,
1992                got,
1993                message,
1994            } => {
1995                let got = columns_pretty(got, humanizer);
1996
1997                writeln!(f, "bad join equivalence {got}: {message}")?
1998            }
1999            BadTopKGroupKey {
2000                source: _,
2001                k,
2002                input_type,
2003            } => {
2004                let input_type = columns_pretty(input_type, humanizer);
2005
2006                writeln!(
2007                    f,
2008                    "TopK group key component references invalid column {k} in columns: {input_type}"
2009                )?
2010            }
2011            BadTopKOrdering {
2012                source: _,
2013                order,
2014                input_type,
2015            } => {
2016                let col = order.column;
2017                let num_cols = input_type.len();
2018                let are = if num_cols == 1 { "is" } else { "are" };
2019                let s = if num_cols == 1 { "" } else { "s" };
2020                let input_type = columns_pretty(input_type, humanizer);
2021
2022                // TODO(cloud#8196)
2023                let mode = HumanizedExplain::new(false);
2024                let order = mode.expr(order, None);
2025
2026                writeln!(
2027                    f,
2028                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2029                )?
2030            }
2031            BadLetRecBindings { source: _ } => {
2032                writeln!(f, "LetRec ids and definitions don't line up")?
2033            }
2034            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2035            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2036            Recursion { error } => writeln!(f, "{error}")?,
2037        }
2038
2039        Ok(())
2040    }
2041}
2042
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        assert_ok!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int16,
                nullable: true,
            }
        ));

        assert_err!(datum_difference_with_column_type(
            &datum,
            &ReprColumnType {
                scalar_type: ReprScalarType::Int32,
                nullable: false,
            }
        ));
    }

    proptest! {
        // Datums generated for a column type may contain dummies, whose behavior is different!
        // Such cases are rejected below, so we allow many more global rejects.
        #![proptest_config(ProptestConfig { cases: 5000, max_global_rejects: 2500, ..Default::default() })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data((src, datum) in any::<SqlColumnType>().prop_flat_map(|src| {
            let datum = arb_datum_for_column(src.clone());
            (Just(src), datum) }
        )) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    proptest! {
        // We run many cases because the data are _random_, and we want to be sure
        // that we have covered sufficient cases. `arb_datum(false)` is expected never to
        // produce dummies; this is asserted below rather than rejected.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(src in any::<SqlColumnType>(), datum in arb_datum(false)) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            let diff = datum_difference_with_column_type(&datum, &typ);
            if datum.is_instance_of(&typ) {
                assert_ok!(diff);
            } else {
                assert_err!(diff);
            }
        }
    }

    /// A record type whose single field is a non-nullable `UInt32` must not accept a
    /// packed list datum holding a single null; both checks must agree on that.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([ReprColumnType {
                    scalar_type: ReprScalarType::UInt32,
                    nullable: false,
                }]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}
2138}