Skip to main content

mz_transform/
typecheck.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Check that the visible type of each query has not been changed
11
12use std::collections::BTreeMap;
13use std::fmt::Write;
14use std::sync::{Arc, Mutex};
15
16use itertools::Itertools;
17use mz_expr::explain::{HumanizedExplain, HumanizerMode};
18use mz_expr::{
19    AggregateExpr, ColumnOrder, Id, JoinImplementation, LocalId, MirRelationExpr, MirScalarExpr,
20    RECURSION_LIMIT, non_nullable_columns,
21};
22use mz_ore::soft_panic_or_log;
23use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
24use mz_repr::adt::range::Range;
25use mz_repr::explain::{DummyHumanizer, ExprHumanizer};
26use mz_repr::{
27    ColumnName, Datum, ReprColumnType, ReprRelationType, ReprScalarBaseType, ReprScalarType,
28};
29
30/// Typechecking contexts as shared by various typechecking passes.
31///
32/// The context is wrapped in an `Arc<Mutex<_>>` so that one context can be shared by multiple
33/// typechecker passes. Shared contexts help catch consistency issues.
34pub type SharedTypecheckingContext = Arc<Mutex<Context>>;
35
36/// Generates an empty context
37pub fn empty_typechecking_context() -> SharedTypecheckingContext {
38    Arc::new(Mutex::new(BTreeMap::new()))
39}
40
41/// The possible forms of inconsistency/errors discovered during typechecking.
42///
43/// Every variant except `Recursion` has a `source` field identifying the MIR term that is home
44/// to the error (though not necessarily the root cause of the error).
45#[derive(Clone, Debug)]
46pub enum TypeError<'a> {
47    /// Unbound identifiers (local or global)
48    Unbound {
49        /// Expression with the bug
50        source: &'a MirRelationExpr,
51        /// The (unbound) identifier referenced
52        id: Id,
53        /// The type `id` was expected to have
54        typ: ReprRelationType,
55    },
56    /// Dereference of a non-existent column
57    NoSuchColumn {
58        /// Expression with the bug
59        source: &'a MirRelationExpr,
60        /// Scalar expression that references an invalid column
61        expr: &'a MirScalarExpr,
62        /// The invalid column referenced
63        col: usize,
64    },
65    /// A single column type does not match
66    MismatchColumn {
67        /// Expression with the bug
68        source: &'a MirRelationExpr,
69        /// The column type we found (`sub` type)
70        got: ReprColumnType,
71        /// The column type we expected (`sup` type)
72        expected: ReprColumnType,
73        /// The differences between these types
74        diffs: Vec<ReprColumnTypeDifference>,
75        /// An explanatory message
76        message: String,
77    },
78    /// Relation column types do not match
79    MismatchColumns {
80        /// Expression with the bug
81        source: &'a MirRelationExpr,
82        /// The column types we found (`sub` type)
83        got: Vec<ReprColumnType>,
84        /// The column types we expected (`sup` type)
85        expected: Vec<ReprColumnType>,
86        /// The differences between these types
87        diffs: Vec<ReprRelationTypeDifference>,
88        /// An explanatory message
89        message: String,
90    },
91    /// A constant row does not have the correct number of columns
92    BadConstantRowLen {
93        /// Expression with the bug
94        source: &'a MirRelationExpr,
95        /// The number of datums in the constant row
96        got: usize,
97        /// The expected type (whose column count that row does not match)
98        expected: Vec<ReprColumnType>,
99    },
100    /// A constant row does not have the correct type
101    BadConstantRow {
102        /// Expression with the bug
103        source: &'a MirRelationExpr,
104        /// The columns that mismatched: each entry pairs the column index with how its datum differed from the expected type
105        mismatches: Vec<(usize, DatumTypeDifference)>,
106        /// The expected type (which that row does not have)
107        expected: Vec<ReprColumnType>,
108        // TODO(mgree) with a good way to get the type of a Datum, we could give a detailed diff here
109        // this is difficult for empty structures where we may not know the element type
110    },
111    /// Projection of a non-existent column
112    BadProject {
113        /// Expression with the bug
114        source: &'a MirRelationExpr,
115        /// The columns projected
116        got: Vec<usize>,
117        /// The input columns (which don't have that column)
118        input_type: Vec<ReprColumnType>,
119    },
120    /// An equivalence class in a join was malformed
121    BadJoinEquivalence {
122        /// Expression with the bug
123        source: &'a MirRelationExpr,
124        /// The join equivalences
125        got: Vec<ReprColumnType>,
126        /// The problem with the join equivalences
127        message: String,
128    },
129    /// TopK grouping by non-existent column
130    BadTopKGroupKey {
131        /// Expression with the bug
132        source: &'a MirRelationExpr,
133        /// The bad column reference in the group key
134        k: usize,
135        /// The input columns (which don't have that column)
136        input_type: Vec<ReprColumnType>,
137    },
138    /// TopK ordering by non-existent column
139    BadTopKOrdering {
140        /// Expression with the bug
141        source: &'a MirRelationExpr,
142        /// The ordering used
143        order: ColumnOrder,
144        /// The input columns (which don't work for that ordering)
145        input_type: Vec<ReprColumnType>,
146    },
147    /// LetRec bindings are malformed
148    BadLetRecBindings {
149        /// Expression with the bug
150        source: &'a MirRelationExpr,
151    },
152    /// Local identifiers are shadowed
153    Shadowing {
154        /// Expression with the bug
155        source: &'a MirRelationExpr,
156        /// The id that was shadowed
157        id: Id,
158    },
159    /// Recursion depth exceeded (the only variant without a `source` expression)
160    Recursion {
161        /// The error that aborted recursion
162        error: RecursionLimitError,
163    },
164    /// A dummy value was found
165    DisallowedDummy {
166        /// The expression with the dummy value
167        source: &'a MirRelationExpr,
168    },
169}
170
171impl<'a> From<RecursionLimitError> for TypeError<'a> {
172    fn from(error: RecursionLimitError) -> Self {
173        TypeError::Recursion { error }
174    }
175}
176
/// The typechecking context: maps each identifier to the column types recorded for it.
177type Context = BTreeMap<Id, Vec<ReprColumnType>>;
178
179/// Characterizes differences between relation types
180///
181/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`.
182#[derive(Clone, Debug, Hash)]
183pub enum ReprRelationTypeDifference {
184    /// `sub` and `sup` don't have the same number of columns
185    Length {
186        /// Number of columns in `sub`
187        len_sub: usize,
188        /// Number of columns in `sup`
189        len_sup: usize,
190    },
191    /// `sub` and `sup` differ at the indicated column
192    Column {
193        /// The index of the column at which `sub` and `sup` differ
194        col: usize,
195        /// The difference between `sub` and `sup` at that column
196        diff: ReprColumnTypeDifference,
197    },
198}
199
200/// Characterizes differences between individual column types
201///
202/// Each constructor indicates a reason why some type `sub` was not a subtype of another type `sup`.
203/// There may be multiple reasons, e.g., `sub` may be missing fields and have fields of different types.
204#[derive(Clone, Debug, Hash)]
205pub enum ReprColumnTypeDifference {
206    /// The `ReprScalarBaseType` of `sub` doesn't match that of `sup`
207    NotSubtype {
208        /// Would-be subtype
209        sub: ReprScalarType,
210        /// Would-be supertype
211        sup: ReprScalarType,
212    },
213    /// `sub` was nullable but `sup` was not
214    Nullability {
215        /// Would-be subtype
216        sub: ReprColumnType,
217        /// Would-be supertype
218        sup: ReprColumnType,
219    },
220    /// Both `sub` and `sup` are a list, map, array, or range, but `sub`'s element type differed from `sup`'s
221    ElementType {
222        /// The type constructor (list, array, etc.)
223        ctor: String,
224        /// The difference in the element type
225        element_type: Box<ReprColumnTypeDifference>,
226    },
227    /// `sub` and `sup` are both records, but `sub` is missing fields present in `sup`
    // NOTE(review): not constructed by the functions visible in this part of the file — confirm it is still produced somewhere.
228    RecordMissingFields {
229        /// The missing fields
230        missing: Vec<ColumnName>,
231    },
232    /// `sub` and `sup` are both records, but some fields in `sub` are not subtypes of fields in `sup`
    // NOTE(review): not constructed by the functions visible in this part of the file — confirm it is still produced somewhere.
233    RecordFields {
234        /// The differences, by field
235        fields: Vec<ReprColumnTypeDifference>,
236    },
237}
238
239impl ReprRelationTypeDifference {
240    /// Returns the same type difference, but ignoring nullability
241    ///
242    /// Returns `None` when _all_ of the differences are due to nullability
243    pub fn ignore_nullability(self) -> Option<Self> {
244        use ReprRelationTypeDifference::*;
245
246        match self {
247            Length { .. } => Some(self),
248            Column { col, diff } => diff.ignore_nullability().map(|diff| Column { col, diff }),
249        }
250    }
251}
252
253impl ReprColumnTypeDifference {
254    /// Returns the same type difference, but ignoring nullability
255    ///
256    /// Returns `None` when _all_ of the differences are due to nullability
257    pub fn ignore_nullability(self) -> Option<Self> {
258        use ReprColumnTypeDifference::*;
259
260        match self {
261            Nullability { .. } => None,
262            NotSubtype { .. } | RecordMissingFields { .. } => Some(self),
263            ElementType { ctor, element_type } => {
264                element_type
265                    .ignore_nullability()
266                    .map(|element_type| ElementType {
267                        ctor,
268                        element_type: Box::new(element_type),
269                    })
270            }
271            RecordFields { fields } => {
272                let fields = fields
273                    .into_iter()
274                    .flat_map(|diff| diff.ignore_nullability())
275                    .collect::<Vec<_>>();
276
277                if fields.is_empty() {
278                    None
279                } else {
280                    Some(RecordFields { fields })
281                }
282            }
283        }
284    }
285}
286
287/// Returns a list of differences that make `sub` not a subtype of `sup`
288///
289/// This function returns an empty list when `sub` is a subtype of `sup`
290pub fn relation_subtype_difference(
291    sub: &[ReprColumnType],
292    sup: &[ReprColumnType],
293) -> Vec<ReprRelationTypeDifference> {
294    let mut diffs = Vec::new();
295
296    if sub.len() != sup.len() {
297        diffs.push(ReprRelationTypeDifference::Length {
298            len_sub: sub.len(),
299            len_sup: sup.len(),
300        });
301
302        // TODO(mgree) we could do an edit-distance computation to report more errors
303        return diffs;
304    }
305
306    diffs.extend(
307        sub.iter()
308            .zip_eq(sup.iter())
309            .enumerate()
310            .flat_map(|(col, (sub_ty, sup_ty))| {
311                column_subtype_difference(sub_ty, sup_ty)
312                    .into_iter()
313                    .map(move |diff| ReprRelationTypeDifference::Column { col, diff })
314            }),
315    );
316
317    diffs
318}
319
320/// Returns a list of differences that make `sub` not a subtype of `sup`
321///
322/// This function returns an empty list when `sub` is a subtype of `sup`
323pub fn column_subtype_difference(
324    sub: &ReprColumnType,
325    sup: &ReprColumnType,
326) -> Vec<ReprColumnTypeDifference> {
327    let mut diffs = scalar_subtype_difference(&sub.scalar_type, &sup.scalar_type);
328
329    if sub.nullable && !sup.nullable {
330        diffs.push(ReprColumnTypeDifference::Nullability {
331            sub: sub.clone(),
332            sup: sup.clone(),
333        });
334    }
335
336    diffs
337}
338
339/// Returns a list of differences that make `sub` not a subtype of `sup`
340///
341/// This function returns an empty list when `sub` is a subtype of `sup`
342pub fn scalar_subtype_difference(
343    sub: &ReprScalarType,
344    sup: &ReprScalarType,
345) -> Vec<ReprColumnTypeDifference> {
346    use ReprScalarType::*;
347
348    let mut diffs = Vec::new();
349
350    match (sub, sup) {
351        (
352            List {
353                element_type: sub_elt,
354                ..
355            },
356            List {
357                element_type: sup_elt,
358                ..
359            },
360        )
361        | (
362            Map {
363                value_type: sub_elt,
364                ..
365            },
366            Map {
367                value_type: sup_elt,
368                ..
369            },
370        )
371        | (
372            Range {
373                element_type: sub_elt,
374                ..
375            },
376            Range {
377                element_type: sup_elt,
378                ..
379            },
380        )
381        | (Array(sub_elt), Array(sup_elt)) => {
382            let ctor = format!("{:?}", ReprScalarBaseType::from(sub));
383            diffs.extend(
384                scalar_subtype_difference(sub_elt, sup_elt)
385                    .into_iter()
386                    .map(|diff| ReprColumnTypeDifference::ElementType {
387                        ctor: ctor.clone(),
388                        element_type: Box::new(diff),
389                    }),
390            );
391        }
392        (
393            Record {
394                fields: sub_fields, ..
395            },
396            Record {
397                fields: sup_fields, ..
398            },
399        ) => {
400            if sub_fields.len() != sup_fields.len() {
401                diffs.push(ReprColumnTypeDifference::NotSubtype {
402                    sub: sub.clone(),
403                    sup: sup.clone(),
404                });
405                return diffs;
406            }
407
408            for (sub_ty, sup_ty) in sub_fields.iter().zip_eq(sup_fields.iter()) {
409                diffs.extend(column_subtype_difference(sub_ty, sup_ty));
410            }
411        }
412        (_, _) => {
413            if ReprScalarBaseType::from(sub) != ReprScalarBaseType::from(sup) {
414                diffs.push(ReprColumnTypeDifference::NotSubtype {
415                    sub: sub.clone(),
416                    sup: sup.clone(),
417                })
418            }
419        }
420    };
421
422    diffs
423}
424
425/// Unions `other` into `typ`, returning a list of differences on failure
426///
427/// This function returns an empty list when `typ` and `other` are a union
428pub fn scalar_union(
429    typ: &mut ReprScalarType,
430    other: &ReprScalarType,
431) -> Vec<ReprColumnTypeDifference> {
432    use ReprScalarType::*;
433
434    let mut diffs = Vec::new();
435
436    // precomputing to appease the borrow checker
437    let ctor = ReprScalarBaseType::from(&*typ);
438    match (typ, other) {
439        (
440            List {
441                element_type: typ_elt,
442            },
443            List {
444                element_type: other_elt,
445            },
446        )
447        | (
448            Map {
449                value_type: typ_elt,
450            },
451            Map {
452                value_type: other_elt,
453            },
454        )
455        | (
456            Range {
457                element_type: typ_elt,
458            },
459            Range {
460                element_type: other_elt,
461            },
462        )
463        | (Array(typ_elt), Array(other_elt)) => {
464            let res = scalar_union(typ_elt.as_mut(), other_elt.as_ref());
465            diffs.extend(
466                res.into_iter()
467                    .map(|diff| ReprColumnTypeDifference::ElementType {
468                        ctor: format!("{ctor:?}"),
469                        element_type: Box::new(diff),
470                    }),
471            );
472        }
473        (
474            Record { fields: typ_fields },
475            Record {
476                fields: other_fields,
477            },
478        ) => {
479            if typ_fields.len() != other_fields.len() {
480                diffs.push(ReprColumnTypeDifference::NotSubtype {
481                    sub: ReprScalarType::Record {
482                        fields: typ_fields.clone(),
483                    },
484                    sup: other.clone(),
485                });
486                return diffs;
487            }
488
489            for (typ_ty, other_ty) in typ_fields.iter_mut().zip_eq(other_fields.iter()) {
490                diffs.extend(column_union(typ_ty, other_ty));
491            }
492        }
493        (typ, _) => {
494            if ctor != ReprScalarBaseType::from(other) {
495                diffs.push(ReprColumnTypeDifference::NotSubtype {
496                    sub: typ.clone(),
497                    sup: other.clone(),
498                })
499            }
500        }
501    };
502
503    diffs
504}
505
506/// Unions `other` into `typ`, returning a list of differences on failure
507///
508/// This function returns an empty list when `typ` and `other` are a union
509pub fn column_union(
510    typ: &mut ReprColumnType,
511    other: &ReprColumnType,
512) -> Vec<ReprColumnTypeDifference> {
513    let diffs = scalar_union(&mut typ.scalar_type, &other.scalar_type);
514
515    if diffs.is_empty() {
516        typ.nullable |= other.nullable;
517    }
518
519    diffs
520}
521
522/// Returns true when it is safe to treat a `sub` row as an `sup` row
523///
524/// In particular, the core types must be equal, and if a column in `sup` is nullable, that column should also be nullable in `sub`
525/// Conversely, it is okay to treat a known non-nullable column as nullable: `sub` may be nullable when `sup` is not
526pub fn is_subtype_of(sub: &[ReprColumnType], sup: &[ReprColumnType]) -> bool {
527    if sub.len() != sup.len() {
528        return false;
529    }
530
531    sub.iter().zip_eq(sup.iter()).all(|(got, known)| {
532        (!known.nullable || got.nullable) && got.scalar_type == known.scalar_type
533    })
534}
535
536/// Characterizes how a Datum differs from a ReprColumnType
537#[derive(Clone, Debug)]
538pub enum DatumTypeDifference {
539    /// Datum was null, but expected a non-null value (i.e., a ReprScalarType)
540    Null {
541        /// The representation scalar type that we expected
542        expected: ReprScalarType,
543    },
544    /// Datum's kind doesn't match the expected ReprScalarType
545    Mismatch {
546        /// The debug rendering of the Datum that we got (we don't store the Datum itself to avoid borrowing issues).
547        /// This debug rendering includes the DatumKind.
548        got_debug: String,
549        /// The representation scalar type that we expected
550        expected: ReprScalarType,
551    },
552    /// Datum's dimensions didn't match the type's dimensions (also used when a record datum has the wrong number of fields)
553    MismatchDimensions {
554        /// The type constructor (record, int2vector, etc.)
555        ctor: String,
556        /// The dimension of the datum that we got
557        got: usize,
558        /// The dimension of the type that we expected
559        expected: usize,
560    },
561    /// There was a nested difference in the element type of some container type
562    ElementType {
563        /// The type constructor (list, array, int2vector, etc.)
564        ctor: String,
565        /// The difference in the element type
566        element_type: Box<DatumTypeDifference>,
567    },
568}
569
570/// Computes the difference between this datum and the specified (representation) column type.
571///
/// Returns `Ok(())` when the datum fits the column type, and otherwise the first
/// difference found. A `Datum::Null` is accepted outright when the column is nullable;
/// everything else is checked against the column's scalar type. Unlike `is_instance_of`,
/// `Datum::Dummy` is accepted for every type.
///
572/// See [`Datum<'a>::is_instance_of`] for just getting a yes/no answer.
573///
574/// See [`Datum<'a>::is_instance_of_sql`] for comparing `Datum`s to `SqlColumnType`s.
575fn datum_difference_with_column_type(
576    datum: &Datum<'_>,
577    column_type: &ReprColumnType,
578) -> Result<(), DatumTypeDifference> {
    // Checks `datum` against a scalar type only; column-level nullability is handled by the caller.
579    fn difference_with_scalar_type(
580        datum: &Datum<'_>,
581        scalar_type: &ReprScalarType,
582    ) -> Result<(), DatumTypeDifference> {
        // Builds the generic "wrong kind of datum" error.
583        fn mismatch(got: &Datum<'_>, expected: &ReprScalarType) -> Result<(), DatumTypeDifference> {
584            Err(DatumTypeDifference::Mismatch {
585                // this will be redacted as appropriate
586                got_debug: format!("{got:?}"),
587                expected: expected.clone(),
588            })
589        }
590
591        if let ReprScalarType::Jsonb = scalar_type {
592            // json type checking: only JSON-shaped datums are accepted, and lists/maps are
            // checked recursively against `Jsonb` itself
593            match datum {
594                Datum::Dummy => Ok(()), // allow dummys (unlike is_instance_of)
595                Datum::Null => Err(DatumTypeDifference::Null {
596                    expected: ReprScalarType::Jsonb,
597                }),
598                Datum::JsonNull
599                | Datum::False
600                | Datum::True
601                | Datum::Numeric(_)
602                | Datum::String(_) => Ok(()),
603                Datum::List(list) => {
604                    for elem in list.iter() {
605                        difference_with_scalar_type(&elem, scalar_type)?;
606                    }
607                    Ok(())
608                }
609                Datum::Map(dict) => {
610                    for (_, val) in dict.iter() {
611                        difference_with_scalar_type(&val, scalar_type)?;
612                    }
613                    Ok(())
614                }
615                _ => mismatch(datum, scalar_type),
616            }
617        } else {
            // Wraps a nested difference with the name of the container it was found in.
618            fn element_type_difference(
619                ctor: &str,
620                element_type: DatumTypeDifference,
621            ) -> DatumTypeDifference {
622                DatumTypeDifference::ElementType {
623                    ctor: ctor.to_string(),
624                    element_type: Box::new(element_type),
625                }
626            }
            // Arm order matters: each datum kind lists its accepting type(s) first and then a
            // per-kind fallback to `mismatch`. There is no catch-all arm in this match.
627            match (datum, scalar_type) {
628                (Datum::Dummy, _) => Ok(()), // allow dummys (unlike is_instance_of)
629                (Datum::Null, _) => Err(DatumTypeDifference::Null {
630                    expected: scalar_type.clone(),
631                }),
632                (Datum::False, ReprScalarType::Bool) => Ok(()),
633                (Datum::False, _) => mismatch(datum, scalar_type),
634                (Datum::True, ReprScalarType::Bool) => Ok(()),
635                (Datum::True, _) => mismatch(datum, scalar_type),
636                (Datum::Int16(_), ReprScalarType::Int16) => Ok(()),
637                (Datum::Int16(_), _) => mismatch(datum, scalar_type),
638                (Datum::Int32(_), ReprScalarType::Int32) => Ok(()),
639                (Datum::Int32(_), _) => mismatch(datum, scalar_type),
640                (Datum::Int64(_), ReprScalarType::Int64) => Ok(()),
641                (Datum::Int64(_), _) => mismatch(datum, scalar_type),
642                (Datum::UInt8(_), ReprScalarType::UInt8) => Ok(()),
643                (Datum::UInt8(_), _) => mismatch(datum, scalar_type),
644                (Datum::UInt16(_), ReprScalarType::UInt16) => Ok(()),
645                (Datum::UInt16(_), _) => mismatch(datum, scalar_type),
646                (Datum::UInt32(_), ReprScalarType::UInt32) => Ok(()),
647                (Datum::UInt32(_), _) => mismatch(datum, scalar_type),
648                (Datum::UInt64(_), ReprScalarType::UInt64) => Ok(()),
649                (Datum::UInt64(_), _) => mismatch(datum, scalar_type),
650                (Datum::Float32(_), ReprScalarType::Float32) => Ok(()),
651                (Datum::Float32(_), _) => mismatch(datum, scalar_type),
652                (Datum::Float64(_), ReprScalarType::Float64) => Ok(()),
653                (Datum::Float64(_), _) => mismatch(datum, scalar_type),
654                (Datum::Date(_), ReprScalarType::Date) => Ok(()),
655                (Datum::Date(_), _) => mismatch(datum, scalar_type),
656                (Datum::Time(_), ReprScalarType::Time) => Ok(()),
657                (Datum::Time(_), _) => mismatch(datum, scalar_type),
658                (Datum::Timestamp(_), ReprScalarType::Timestamp { .. }) => Ok(()),
659                (Datum::Timestamp(_), _) => mismatch(datum, scalar_type),
660                (Datum::TimestampTz(_), ReprScalarType::TimestampTz { .. }) => Ok(()),
661                (Datum::TimestampTz(_), _) => mismatch(datum, scalar_type),
662                (Datum::Interval(_), ReprScalarType::Interval) => Ok(()),
663                (Datum::Interval(_), _) => mismatch(datum, scalar_type),
664                (Datum::Bytes(_), ReprScalarType::Bytes) => Ok(()),
665                (Datum::Bytes(_), _) => mismatch(datum, scalar_type),
666                (Datum::String(_), ReprScalarType::String) => Ok(()),
667                (Datum::String(_), _) => mismatch(datum, scalar_type),
668                (Datum::Uuid(_), ReprScalarType::Uuid) => Ok(()),
669                (Datum::Uuid(_), _) => mismatch(datum, scalar_type),
670                (Datum::Array(array), ReprScalarType::Array(t)) => {
671                    for e in array.elements().iter() {
                        // null elements are always accepted in arrays
672                        if let Datum::Null = e {
673                            continue;
674                        }
675
676                        difference_with_scalar_type(&e, t)
677                            .map_err(|e| element_type_difference("array", e))?;
678                    }
679                    Ok(())
680                }
681                (Datum::Array(array), ReprScalarType::Int2Vector) => {
682                    if !array.has_int2vector_dims() {
683                        // Int2Vector is 1-D, but empty arrays use 0 dimensions (PostgreSQL convention).
684                        // The error just reports 1 as the expected dimension, but 0 dimensions are allowed when empty.
685                        return Err(DatumTypeDifference::MismatchDimensions {
686                            ctor: "int2vector".to_string(),
687                            got: array.dims().len(),
688                            expected: 1,
689                        });
690                    }
691
                    // unlike plain arrays, null elements are not skipped here, so they are
                    // reported as a `Null` difference nested under "int2vector"
692                    for e in array.elements().iter() {
693                        difference_with_scalar_type(&e, &ReprScalarType::Int16)
694                            .map_err(|e| element_type_difference("int2vector", e))?;
695                    }
696
697                    Ok(())
698                }
699                (Datum::Array(_), _) => mismatch(datum, scalar_type),
700                (Datum::List(list), ReprScalarType::List { element_type, .. }) => {
701                    for e in list.iter() {
                        // null elements are always accepted in lists
702                        if let Datum::Null = e {
703                            continue;
704                        }
705
706                        difference_with_scalar_type(&e, element_type)
707                            .map_err(|e| element_type_difference("list", e))?;
708                    }
709                    Ok(())
710                }
                // a record datum is a list checked positionally against the record's fields
711                (Datum::List(list), ReprScalarType::Record { fields, .. }) => {
712                    let len = list.iter().count();
713                    if len != fields.len() {
714                        return Err(DatumTypeDifference::MismatchDimensions {
715                            ctor: "record".to_string(),
716                            got: len,
717                            expected: fields.len(),
718                        });
719                    }
720
721                    for (e, t) in list.iter().zip_eq(fields) {
                        // nulls are governed by each field's own nullability; note that this
                        // `Null` difference is returned directly, not wrapped in "record"
722                        if let Datum::Null = e {
723                            if t.nullable {
724                                continue;
725                            } else {
726                                return Err(DatumTypeDifference::Null {
727                                    expected: t.scalar_type.clone(),
728                                });
729                            }
730                        }
731
732                        difference_with_scalar_type(&e, &t.scalar_type)
733                            .map_err(|e| element_type_difference("record", e))?;
734                    }
735                    Ok(())
736                }
737                (Datum::List(_), _) => mismatch(datum, scalar_type),
738                (Datum::Map(map), ReprScalarType::Map { value_type, .. }) => {
                    // only values are checked; keys are ignored
739                    for (_, v) in map.iter() {
                        // null values are always accepted in maps
740                        if let Datum::Null = v {
741                            continue;
742                        }
743
744                        difference_with_scalar_type(&v, value_type)
745                            .map_err(|e| element_type_difference("map", e))?;
746                    }
747                    Ok(())
748                }
749                (Datum::Map(_), _) => mismatch(datum, scalar_type),
                // `Jsonb` was handled in the branch above, so a JSON null never fits here
750                (Datum::JsonNull, _) => mismatch(datum, scalar_type),
751                (Datum::Numeric(_), ReprScalarType::Numeric) => Ok(()),
752                (Datum::Numeric(_), _) => mismatch(datum, scalar_type),
753                (Datum::MzTimestamp(_), ReprScalarType::MzTimestamp) => Ok(()),
754                (Datum::MzTimestamp(_), _) => mismatch(datum, scalar_type),
755                (Datum::Range(Range { inner }), ReprScalarType::Range { element_type }) => {
756                    match inner {
                        // a range without an inner value (presumably the empty range) fits any element type
757                        None => Ok(()),
758                        Some(inner) => {
                            // check whichever bounds are present against the element type
759                            if let Some(b) = inner.lower.bound {
760                                difference_with_scalar_type(&b.datum(), element_type)
761                                    .map_err(|e| element_type_difference("range", e))?;
762                            }
763                            if let Some(b) = inner.upper.bound {
764                                difference_with_scalar_type(&b.datum(), element_type)
765                                    .map_err(|e| element_type_difference("range", e))?;
766                            }
767                            Ok(())
768                        }
769                    }
770                }
771                (Datum::Range(_), _) => mismatch(datum, scalar_type),
772                (Datum::MzAclItem(_), ReprScalarType::MzAclItem) => Ok(()),
773                (Datum::MzAclItem(_), _) => mismatch(datum, scalar_type),
774                (Datum::AclItem(_), ReprScalarType::AclItem) => Ok(()),
775                (Datum::AclItem(_), _) => mismatch(datum, scalar_type),
776            }
777        }
778    }
    // A null in a nullable column is fine regardless of the scalar type.
779    if column_type.nullable {
780        if let Datum::Null = datum {
781            return Ok(());
782        }
783    }
784    difference_with_scalar_type(datum, &column_type.scalar_type)
785}
786
787fn row_difference_with_column_types<'a>(
788    source: &'a MirRelationExpr,
789    datums: &Vec<Datum<'_>>,
790    column_types: &[ReprColumnType],
791) -> Result<(), TypeError<'a>> {
792    // correct length
793    if datums.len() != column_types.len() {
794        return Err(TypeError::BadConstantRowLen {
795            source,
796            got: datums.len(),
797            expected: column_types.to_vec(),
798        });
799    }
800
801    // correct types
802    let mut mismatches = Vec::new();
803    for (i, (d, ty)) in datums.iter().zip_eq(column_types.iter()).enumerate() {
804        if let Err(e) = datum_difference_with_column_type(d, ty) {
805            mismatches.push((i, e));
806        }
807    }
808    if !mismatches.is_empty() {
809        return Err(TypeError::BadConstantRow {
810            source,
811            mismatches,
812            expected: column_types.to_vec(),
813        });
814    }
815
816    Ok(())
817}
818/// Check that the visible type of each query has not been changed
819#[derive(Debug)]
820pub struct Typecheck {
821    /// The known types of the queries so far
822    ctx: SharedTypecheckingContext,
823    /// Whether or not this is the first run of the transform
824    disallow_new_globals: bool,
825    /// Whether or not to be strict about join equivalences having the same nullability
826    strict_join_equivalences: bool,
827    /// Whether or not to disallow dummy values
828    disallow_dummy: bool,
829    /// Recursion guard for checked recursion
830    recursion_guard: RecursionGuard,
831}
832
833impl CheckedRecursion for Typecheck {
834    fn recursion_guard(&self) -> &RecursionGuard {
835        &self.recursion_guard
836    }
837}
838
839impl Typecheck {
840    /// Creates a typechecking consistency checking pass using a given shared context
841    pub fn new(ctx: SharedTypecheckingContext) -> Self {
842        Self {
843            ctx,
844            disallow_new_globals: false,
845            strict_join_equivalences: false,
846            disallow_dummy: false,
847            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
848        }
849    }
850
851    /// New non-transient global IDs will be treated as an error
852    ///
853    /// Only turn this on after the context has been appropriately populated by, e.g., an earlier run
854    pub fn disallow_new_globals(mut self) -> Self {
855        self.disallow_new_globals = true;
856        self
857    }
858
859    /// Equivalence classes in joins must not only agree on scalar type, but also on nullability
860    ///
861    /// Only turn this on before `JoinImplementation`
862    pub fn strict_join_equivalences(mut self) -> Self {
863        self.strict_join_equivalences = true;
864
865        self
866    }
867
868    /// Disallow dummy values
869    pub fn disallow_dummy(mut self) -> Self {
870        self.disallow_dummy = true;
871        self
872    }
873
874    /// Returns the type of a relation expression or a type error.
875    ///
876    /// This function is careful to check validity, not just find out the type.
877    ///
878    /// It should be linear in the size of the AST.
879    ///
880    /// ??? should we also compute keys and return a `ReprRelationType`?
881    ///   ggevay: Checking keys would have the same problem as checking nullability: key inference
882    ///   is very heuristic (even more so than nullability inference), so it's almost impossible to
883    ///   reliably keep it stable across transformations.
884    pub fn typecheck<'a>(
885        &self,
886        expr: &'a MirRelationExpr,
887        ctx: &Context,
888    ) -> Result<Vec<ReprColumnType>, TypeError<'a>> {
889        use MirRelationExpr::*;
890
891        self.checked_recur(|tc| match expr {
892            Constant { typ, rows } => {
893                if let Ok(rows) = rows {
894                    for (row, _id) in rows {
895                        let datums = row.unpack();
896
897                        let col_types = typ
898                            .column_types
899                            .iter()
900                            .cloned()
901                            .collect_vec();
902                        row_difference_with_column_types(
903                            expr, &datums, &col_types,
904                        )?;
905
906                        if self.disallow_dummy
907                            && datums.iter().any(|d| d == &mz_repr::Datum::Dummy)
908                        {
909                            return Err(TypeError::DisallowedDummy {
910                                source: expr,
911                            });
912                        }
913                    }
914                }
915
916                Ok(typ.column_types.iter().cloned().collect_vec())
917            }
918            Get { typ, id, .. } => {
919                if let Id::Global(_global_id) = id {
920                    if !ctx.contains_key(id) {
921                        // TODO(mgree) pass QueryContext through to check these types
922                        return Ok(typ.column_types.iter().cloned().collect_vec());
923                    }
924                }
925
926                let ctx_typ = ctx.get(id).ok_or_else(|| TypeError::Unbound {
927                    source: expr,
928                    id: id.clone(),
929                    typ: typ.clone(),
930                })?;
931
932                let column_types = typ.column_types.iter().cloned().collect_vec();
933
934                // covariant: the ascribed type must be a subtype of the actual type in the context
935                let diffs = relation_subtype_difference(&column_types, ctx_typ)
936                    .into_iter()
937                    .flat_map(|diff| diff.ignore_nullability())
938                    .collect::<Vec<_>>();
939
940                if !diffs.is_empty() {
941                    return Err(TypeError::MismatchColumns {
942                        source: expr,
943                        got: column_types,
944                        expected: ctx_typ.clone(),
945                        diffs,
946                        message: "annotation did not match context type".to_string(),
947                    });
948                }
949
950                Ok(column_types)
951            }
952            Project { input, outputs } => {
953                let t_in = tc.typecheck(input, ctx)?;
954
955                for x in outputs {
956                    if *x >= t_in.len() {
957                        return Err(TypeError::BadProject {
958                            source: expr,
959                            got: outputs.clone(),
960                            input_type: t_in,
961                        });
962                    }
963                }
964
965                Ok(outputs.iter().map(|col| t_in[*col].clone()).collect())
966            }
967            Map { input, scalars } => {
968                let mut t_in = tc.typecheck(input, ctx)?;
969
970                for scalar_expr in scalars.iter() {
971                    t_in.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
972
973                    if self.disallow_dummy && scalar_expr.contains_dummy() {
974                        return Err(TypeError::DisallowedDummy {
975                            source: expr,
976                        });
977                    }
978                }
979
980                Ok(t_in)
981            }
982            FlatMap { input, func, exprs } => {
983                let mut t_in = tc.typecheck(input, ctx)?;
984
985                let mut t_exprs = Vec::with_capacity(exprs.len());
986                for scalar_expr in exprs {
987                    t_exprs.push(tc.typecheck_scalar(scalar_expr, expr, &t_in)?);
988
989                    if self.disallow_dummy && scalar_expr.contains_dummy() {
990                        return Err(TypeError::DisallowedDummy {
991                            source: expr,
992                        });
993                    }
994                }
995                // TODO(mgree) check t_exprs agrees with `func`'s input type
996
997                let t_out: Vec<ReprColumnType> = func
998                    .output_type().column_types;
999
1000                // FlatMap extends the existing columns
1001                t_in.extend(t_out);
1002                Ok(t_in)
1003            }
1004            Filter { input, predicates } => {
1005                let mut t_in = tc.typecheck(input, ctx)?;
1006
1007                // Set as nonnull any columns where null values would cause
1008                // any predicate to evaluate to null.
1009                for column in non_nullable_columns(predicates) {
1010                    t_in[column].nullable = false;
1011                }
1012
1013                for scalar_expr in predicates {
1014                    let t = tc.typecheck_scalar(scalar_expr, expr, &t_in)?;
1015
1016                    // filter condition must be boolean
1017                    // ignoring nullability: null is treated as false
1018                    // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1019                    if t.scalar_type != ReprScalarType::Bool {
1020                        let sub = t.scalar_type.clone();
1021
1022                        return Err(TypeError::MismatchColumn {
1023                            source: expr,
1024                            got: t,
1025                            expected: ReprColumnType {
1026                                scalar_type: ReprScalarType::Bool,
1027                                nullable: true,
1028                            },
1029                            diffs: vec![ReprColumnTypeDifference::NotSubtype {
1030                                sub,
1031                                sup: ReprScalarType::Bool,
1032                            }],
1033                            message: "expected boolean condition".to_string(),
1034                        });
1035                    }
1036
1037                    if self.disallow_dummy && scalar_expr.contains_dummy() {
1038                        return Err(TypeError::DisallowedDummy {
1039                            source: expr,
1040                        });
1041                    }
1042                }
1043
1044                Ok(t_in)
1045            }
1046            Join {
1047                inputs,
1048                equivalences,
1049                implementation,
1050            } => {
1051                let mut t_in_global = Vec::new();
1052                let mut t_in_local = vec![Vec::new(); inputs.len()];
1053
1054                for (i, input) in inputs.iter().enumerate() {
1055                    let input_t = tc.typecheck(input, ctx)?;
1056                    t_in_global.extend(input_t.clone());
1057                    t_in_local[i] = input_t;
1058                }
1059
1060                for eq_class in equivalences {
1061                    let mut t_exprs: Vec<ReprColumnType> = Vec::with_capacity(eq_class.len());
1062
1063                    let mut all_nullable = true;
1064
1065                    for scalar_expr in eq_class {
1066                        // Note: the equivalences have global column references
1067                        let t_expr = tc.typecheck_scalar(scalar_expr, expr, &t_in_global)?;
1068
1069                        if !t_expr.nullable {
1070                            all_nullable = false;
1071                        }
1072
1073                        if let Some(t_first) = t_exprs.get(0) {
1074                            let diffs = scalar_subtype_difference(
1075                                &t_expr.scalar_type,
1076                                &t_first.scalar_type,
1077                            );
1078                            if !diffs.is_empty() {
1079                                return Err(TypeError::MismatchColumn {
1080                                    source: expr,
1081                                    got: t_expr,
1082                                    expected: t_first.clone(),
1083                                    diffs,
1084                                    message: "equivalence class members \
1085                                        have different scalar types"
1086                                        .to_string(),
1087                                });
1088                            }
1089
1090                            // equivalences may or may not match on nullability
1091                            // before JoinImplementation runs, nullability should match.
1092                            // but afterwards, some nulls may appear that are actually being filtered out elsewhere
1093                            if self.strict_join_equivalences {
1094                                if t_expr.nullable != t_first.nullable {
1095                                    let sub = t_expr.clone();
1096                                    let sup = t_first.clone();
1097
1098                                    let err = TypeError::MismatchColumn {
1099                                        source: expr,
1100                                        got: t_expr.clone(),
1101                                        expected: t_first.clone(),
1102                                        diffs: vec![
1103                                            ReprColumnTypeDifference::Nullability { sub, sup },
1104                                        ],
1105                                        message: "equivalence class members have \
1106                                            different nullability (and join \
1107                                            equivalence checking is strict)"
1108                                            .to_string(),
1109                                    };
1110
1111                                    // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1112                                    ::tracing::debug!("{err}");
1113                                }
1114                            }
1115                        }
1116
1117                        if self.disallow_dummy && scalar_expr.contains_dummy() {
1118                            return Err(TypeError::DisallowedDummy {
1119                                source: expr,
1120                            });
1121                        }
1122
1123                        t_exprs.push(t_expr);
1124                    }
1125
1126                    if self.strict_join_equivalences && all_nullable {
1127                        let err = TypeError::BadJoinEquivalence {
1128                            source: expr,
1129                            got: t_exprs,
1130                            message: "all expressions were nullable (and join equivalence checking is strict)".to_string(),
1131                        };
1132
1133                        // TODO(mgree) this imprecision should be resolved, but we need to fix the optimizer
1134                        ::tracing::debug!("{err}");
1135                    }
1136                }
1137
1138                // check that the join implementation is consistent
1139                match implementation {
1140                    JoinImplementation::Differential((start_idx, first_key, _), others) => {
1141                        if let Some(key) = first_key {
1142                            for k in key {
1143                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*start_idx])?;
1144                            }
1145                        }
1146
1147                        for (idx, key, _) in others {
1148                            for k in key {
1149                                let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1150                            }
1151                        }
1152                    }
1153                    JoinImplementation::DeltaQuery(plans) => {
1154                        for plan in plans {
1155                            for (idx, key, _) in plan {
1156                                for k in key {
1157                                    let _ = tc.typecheck_scalar(k, expr, &t_in_local[*idx])?;
1158                                }
1159                            }
1160                        }
1161                    }
1162                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, key, consts) => {
1163                        let typ: Vec<ReprColumnType> = key
1164                            .iter()
1165                            .map(|k| tc.typecheck_scalar(k, expr, &t_in_global))
1166                            .collect::<Result<Vec<ReprColumnType>, TypeError>>()?;
1167
1168                        for row in consts {
1169                            let datums = row.unpack();
1170
1171                            row_difference_with_column_types(expr, &datums, &typ)?;
1172                        }
1173                    }
1174                    JoinImplementation::Unimplemented => (),
1175                }
1176
1177                Ok(t_in_global)
1178            }
1179            Reduce {
1180                input,
1181                group_key,
1182                aggregates,
1183                monotonic: _,
1184                expected_group_size: _,
1185            } => {
1186                let t_in = tc.typecheck(input, ctx)?;
1187
1188                let mut t_out = group_key
1189                    .iter()
1190                    .map(|scalar_expr| tc.typecheck_scalar(scalar_expr, expr, &t_in))
1191                    .collect::<Result<Vec<_>, _>>()?;
1192
1193                    if self.disallow_dummy
1194                        && group_key
1195                            .iter()
1196                            .any(|scalar_expr| scalar_expr.contains_dummy())
1197                    {
1198                        return Err(TypeError::DisallowedDummy {
1199                            source: expr,
1200                        });
1201                    }
1202
1203                for agg in aggregates {
1204                    t_out.push(tc.typecheck_aggregate(agg, expr, &t_in)?);
1205                }
1206
1207                Ok(t_out)
1208            }
1209            TopK {
1210                input,
1211                group_key,
1212                order_key,
1213                limit: _,
1214                offset: _,
1215                monotonic: _,
1216                expected_group_size: _,
1217            } => {
1218                let t_in = tc.typecheck(input, ctx)?;
1219
1220                for &k in group_key {
1221                    if k >= t_in.len() {
1222                        return Err(TypeError::BadTopKGroupKey {
1223                            source: expr,
1224                            k,
1225                            input_type: t_in,
1226                        });
1227                    }
1228                }
1229
1230                for order in order_key {
1231                    if order.column >= t_in.len() {
1232                        return Err(TypeError::BadTopKOrdering {
1233                            source: expr,
1234                            order: order.clone(),
1235                            input_type: t_in,
1236                        });
1237                    }
1238                }
1239
1240                Ok(t_in)
1241            }
1242            Negate { input } => tc.typecheck(input, ctx),
1243            Threshold { input } => tc.typecheck(input, ctx),
1244            Union { base, inputs } => {
1245                let mut t_base = tc.typecheck(base, ctx)?;
1246
1247                for input in inputs {
1248                    let t_input = tc.typecheck(input, ctx)?;
1249
1250                    let len_sub = t_base.len();
1251                    let len_sup = t_input.len();
1252                    if len_sub != len_sup {
1253                        return Err(TypeError::MismatchColumns {
1254                            source: expr,
1255                            got: t_base.clone(),
1256                            expected: t_input,
1257                            diffs: vec![ReprRelationTypeDifference::Length {
1258                                len_sub,
1259                                len_sup,
1260                            }],
1261                            message: "Union branches have different numbers of columns".to_string(),
1262                        });
1263                    }
1264
1265                    for (base_col, input_col) in t_base.iter_mut().zip_eq(t_input) {
1266                        let diffs = column_union(base_col, &input_col);
1267                        if !diffs.is_empty() {
1268                            return Err(TypeError::MismatchColumn {
1269                                    source: expr,
1270                                    got: input_col,
1271                                    expected: base_col.clone(),
1272                                    diffs,
1273                                    message:
1274                                        "couldn't compute union of column types in Union"
1275                                    .to_string(),
1276                            });
1277                        }
1278
1279                    }
1280                }
1281
1282                Ok(t_base)
1283            }
1284            Let { id, value, body } => {
1285                let t_value = tc.typecheck(value, ctx)?;
1286
1287                let binding = Id::Local(*id);
1288                if ctx.contains_key(&binding) {
1289                    return Err(TypeError::Shadowing {
1290                        source: expr,
1291                        id: binding,
1292                    });
1293                }
1294
1295                let mut body_ctx = ctx.clone();
1296                body_ctx.insert(Id::Local(*id), t_value);
1297
1298                tc.typecheck(body, &body_ctx)
1299            }
1300            LetRec { ids, values, body, limits: _ } => {
1301                if ids.len() != values.len() {
1302                    return Err(TypeError::BadLetRecBindings { source: expr });
1303                }
1304
1305                // temporary hack: steal info from the Gets inside to learn the expected types
1306                // if no get occurs in any definition or the body, that means that relation is dead code (which is okay)
1307                let mut ctx = ctx.clone();
1308                // calling tc.collect_recursive_variable_types(expr, ...) triggers a panic due to nested letrecs with shadowing IDs
1309                for inner_expr in values.iter().chain(std::iter::once(body.as_ref())) {
1310                    tc.collect_recursive_variable_types(inner_expr, ids, &mut ctx)?;
1311                }
1312
1313                for (id, value) in ids.iter().zip_eq(values.iter()) {
1314                    let typ = tc.typecheck(value, &ctx)?;
1315
1316                    let id = Id::Local(id.clone());
1317                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1318                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1319                            // we expect an EXACT match, but don't care about nullability
1320                            let diffs = column_union(base_col, &input_col);
1321                            if !diffs.is_empty() {
1322                                 return Err(TypeError::MismatchColumn {
1323                                        source: expr,
1324                                        got: input_col,
1325                                        expected: base_col.clone(),
1326                                        diffs,
1327                                        message:
1328                                            "couldn't compute union of column types in LetRec"
1329                                        .to_string(),
1330                                    })
1331                            }
1332                        }
1333                    } else {
1334                        // dead code: no `Get` references this relation anywhere. we record the type anyway
1335                        ctx.insert(id, typ);
1336                    }
1337                }
1338
1339                tc.typecheck(body, &ctx)
1340            }
1341            ArrangeBy { input, keys } => {
1342                let t_in = tc.typecheck(input, ctx)?;
1343
1344                for key in keys {
1345                    for k in key {
1346                        let _ = tc.typecheck_scalar(k, expr, &t_in)?;
1347                    }
1348                }
1349
1350                Ok(t_in)
1351            }
1352        })
1353    }
1354
1355    /// Traverses a term to collect the types of given ids.
1356    ///
1357    /// LetRec doesn't have type info stored in it. Until we change the MIR to track that information explicitly, we have to rebuild it from looking at the term.
1358    fn collect_recursive_variable_types<'a>(
1359        &self,
1360        expr: &'a MirRelationExpr,
1361        ids: &[LocalId],
1362        ctx: &mut Context,
1363    ) -> Result<(), TypeError<'a>> {
1364        use MirRelationExpr::*;
1365
1366        self.checked_recur(|tc| {
1367            match expr {
1368                Get {
1369                    id: Id::Local(id),
1370                    typ,
1371                    ..
1372                } => {
1373                    if !ids.contains(id) {
1374                        return Ok(());
1375                    }
1376
1377                    let id = Id::Local(id.clone());
1378                    if let Some(ctx_typ) = ctx.get_mut(&id) {
1379                        let typ = typ.column_types.iter().cloned().collect_vec();
1380
1381                        if ctx_typ.len() != typ.len() {
1382                            let diffs = relation_subtype_difference(&typ, ctx_typ);
1383
1384                            return Err(TypeError::MismatchColumns {
1385                                source: expr,
1386                                got: typ,
1387                                expected: ctx_typ.clone(),
1388                                diffs,
1389                                message: "environment and type annotation did not match"
1390                                    .to_string(),
1391                            });
1392                        }
1393
1394                        for (base_col, input_col) in ctx_typ.iter_mut().zip_eq(typ) {
1395                            let diffs = column_union(base_col, &input_col);
1396                            if !diffs.is_empty() {
1397                                return Err(TypeError::MismatchColumn {
1398                                    source: expr,
1399                                    got: input_col,
1400                                    expected: base_col.clone(),
1401                                    diffs,
1402                                    message:
1403                                        "couldn't compute union of column types in Get and context"
1404                                            .to_string(),
1405                                });
1406                            }
1407                        }
1408                    } else {
1409                        ctx.insert(id, typ.column_types.iter().cloned().collect_vec());
1410                    }
1411                }
1412                Get {
1413                    id: Id::Global(..), ..
1414                }
1415                | Constant { .. } => (),
1416                Let { id, value, body } => {
1417                    tc.collect_recursive_variable_types(value, ids, ctx)?;
1418
1419                    // we've shadowed the id
1420                    if ids.contains(id) {
1421                        return Err(TypeError::Shadowing {
1422                            source: expr,
1423                            id: Id::Local(*id),
1424                        });
1425                    }
1426
1427                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1428                }
1429                LetRec {
1430                    ids: inner_ids,
1431                    values,
1432                    body,
1433                    limits: _,
1434                } => {
1435                    for inner_id in inner_ids {
1436                        if ids.contains(inner_id) {
1437                            return Err(TypeError::Shadowing {
1438                                source: expr,
1439                                id: Id::Local(*inner_id),
1440                            });
1441                        }
1442                    }
1443
1444                    for value in values {
1445                        tc.collect_recursive_variable_types(value, ids, ctx)?;
1446                    }
1447
1448                    tc.collect_recursive_variable_types(body, ids, ctx)?;
1449                }
1450                Project { input, .. }
1451                | Map { input, .. }
1452                | FlatMap { input, .. }
1453                | Filter { input, .. }
1454                | Reduce { input, .. }
1455                | TopK { input, .. }
1456                | Negate { input }
1457                | Threshold { input }
1458                | ArrangeBy { input, .. } => {
1459                    tc.collect_recursive_variable_types(input, ids, ctx)?;
1460                }
1461                Join { inputs, .. } => {
1462                    for input in inputs {
1463                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1464                    }
1465                }
1466                Union { base, inputs } => {
1467                    tc.collect_recursive_variable_types(base, ids, ctx)?;
1468
1469                    for input in inputs {
1470                        tc.collect_recursive_variable_types(input, ids, ctx)?;
1471                    }
1472                }
1473            }
1474
1475            Ok(())
1476        })
1477    }
1478
1479    fn typecheck_scalar<'a>(
1480        &self,
1481        expr: &'a MirScalarExpr,
1482        source: &'a MirRelationExpr,
1483        column_types: &[ReprColumnType],
1484    ) -> Result<ReprColumnType, TypeError<'a>> {
1485        use MirScalarExpr::*;
1486
1487        self.checked_recur(|tc| match expr {
1488            Column(i, _) => match column_types.get(*i) {
1489                Some(ty) => Ok(ty.clone()),
1490                None => Err(TypeError::NoSuchColumn {
1491                    source,
1492                    expr,
1493                    col: *i,
1494                }),
1495            },
1496            Literal(row, typ) => {
1497                let typ = typ.clone();
1498                if let Ok(row) = row {
1499                    let datums = row.unpack();
1500
1501                    row_difference_with_column_types(source, &datums, std::slice::from_ref(&typ))?;
1502                }
1503
1504                Ok(typ)
1505            }
1506            CallUnmaterializable(func) => Ok(func.output_type()),
1507            CallUnary { expr, func } => {
1508                let typ_in = tc.typecheck_scalar(expr, source, column_types)?;
1509                let typ_out = func.output_type(typ_in);
1510                Ok(typ_out)
1511            }
1512            CallBinary { expr1, expr2, func } => {
1513                let typ_in1 = tc.typecheck_scalar(expr1, source, column_types)?;
1514                let typ_in2 = tc.typecheck_scalar(expr2, source, column_types)?;
1515                let typ_out = func.output_type(&[typ_in1, typ_in2]);
1516                Ok(typ_out)
1517            }
1518            CallVariadic { exprs, func } => Ok(func.output_type(
1519                exprs
1520                    .iter()
1521                    .map(|e| tc.typecheck_scalar(e, source, column_types))
1522                    .collect::<Result<Vec<_>, TypeError>>()?,
1523            )),
1524            If { cond, then, els } => {
1525                let cond_type = tc.typecheck_scalar(cond, source, column_types)?;
1526
1527                // condition must be boolean
1528                // ignoring nullability: null is treated as false
1529                // NB this behavior is slightly different from columns_match (for which we would set nullable to false in the expected type)
1530                if cond_type.scalar_type != ReprScalarType::Bool {
1531                    let sub = cond_type.scalar_type.clone();
1532
1533                    return Err(TypeError::MismatchColumn {
1534                        source,
1535                        got: cond_type,
1536                        expected: ReprColumnType {
1537                            scalar_type: ReprScalarType::Bool,
1538                            nullable: true,
1539                        },
1540                        diffs: vec![ReprColumnTypeDifference::NotSubtype {
1541                            sub,
1542                            sup: ReprScalarType::Bool,
1543                        }],
1544                        message: "expected boolean condition".to_string(),
1545                    });
1546                }
1547
1548                let mut then_type = tc.typecheck_scalar(then, source, column_types)?;
1549                let else_type = tc.typecheck_scalar(els, source, column_types)?;
1550
1551                let diffs = column_union(&mut then_type, &else_type);
1552                if !diffs.is_empty() {
1553                    return Err(TypeError::MismatchColumn {
1554                        source,
1555                        got: then_type,
1556                        expected: else_type,
1557                        diffs,
1558                        message: "couldn't compute union of column types for If".to_string(),
1559                    });
1560                }
1561
1562                Ok(then_type)
1563            }
1564        })
1565    }
1566
1567    /// Typecheck an `AggregateExpr`
1568    pub fn typecheck_aggregate<'a>(
1569        &self,
1570        expr: &'a AggregateExpr,
1571        source: &'a MirRelationExpr,
1572        column_types: &[ReprColumnType],
1573    ) -> Result<ReprColumnType, TypeError<'a>> {
1574        self.checked_recur(|tc| {
1575            let t_in = tc.typecheck_scalar(&expr.expr, source, column_types)?;
1576
1577            // TODO check that t_in is actually acceptable for `func`
1578
1579            Ok(expr.func.output_type(t_in))
1580        })
1581    }
1582}
1583
/// Reports a type error at a severity chosen by the caller
///
/// `type_error!(severity, ...)` takes a boolean `severity` followed by `format!`-style arguments:
///
/// * if `severity` is `true`, the message is passed to `soft_panic_or_log!`;
/// * otherwise, the message is emitted as a `tracing` debug-level event.
macro_rules! type_error {
    ($severity:expr, $($arg:tt)+) => {{
        let severe: bool = $severity;
        if severe {
            soft_panic_or_log!($($arg)+);
        } else {
            ::tracing::debug!($($arg)+);
        }
    }}
}
1596
1597impl crate::Transform for Typecheck {
1598    fn name(&self) -> &'static str {
1599        "Typecheck"
1600    }
1601
1602    fn actually_perform_transform(
1603        &self,
1604        relation: &mut MirRelationExpr,
1605        transform_ctx: &mut crate::TransformCtx,
1606    ) -> Result<(), crate::TransformError> {
1607        let mut typecheck_ctx = self.ctx.lock().expect("typecheck ctx");
1608
1609        let expected = transform_ctx
1610            .global_id
1611            .map_or_else(|| None, |id| typecheck_ctx.get(&Id::Global(id)));
1612
1613        if let Some(id) = transform_ctx.global_id {
1614            if self.disallow_new_globals
1615                && expected.is_none()
1616                && transform_ctx.global_id.is_some()
1617                && !id.is_transient()
1618            {
1619                type_error!(
1620                    false, // not severe
1621                    "type warning: new non-transient global id {id}\n{}",
1622                    relation.pretty()
1623                );
1624            }
1625        }
1626
1627        let got = self.typecheck(relation, &typecheck_ctx);
1628
1629        let humanizer = mz_repr::explain::DummyHumanizer;
1630
1631        match (got, expected) {
1632            (Ok(got), Some(expected)) => {
1633                let id = transform_ctx.global_id.unwrap();
1634
1635                // contravariant: global types can be updated
1636                let diffs = relation_subtype_difference(expected, &got);
1637                if !diffs.is_empty() {
1638                    // SEVERE only if got and expected have true differences, not just nullability
1639                    let severity = diffs
1640                        .iter()
1641                        .any(|diff| diff.clone().ignore_nullability().is_some());
1642
1643                    let err = TypeError::MismatchColumns {
1644                        source: relation,
1645                        got,
1646                        expected: expected.clone(),
1647                        diffs,
1648                        message: format!(
1649                            "a global id {id}'s type changed (was `expected` which should be a subtype of `got`) "
1650                        ),
1651                    };
1652
1653                    type_error!(severity, "type error in known global id {id}:\n{err}");
1654                }
1655            }
1656            (Ok(got), None) => {
1657                if let Some(id) = transform_ctx.global_id {
1658                    typecheck_ctx.insert(Id::Global(id), got);
1659                }
1660            }
1661            (Err(err), _) => {
1662                let (expected, binding) = match expected {
1663                    Some(expected) => {
1664                        let id = transform_ctx.global_id.unwrap();
1665                        (
1666                            format!("expected type {}\n", columns_pretty(expected, &humanizer)),
1667                            format!("known global id {id}"),
1668                        )
1669                    }
1670                    None => ("".to_string(), "transient query".to_string()),
1671                };
1672
1673                type_error!(
1674                    true, // SEVERE: the transformed code is inconsistent
1675                    "type error in {binding}:\n{err}\n{expected}{}",
1676                    relation.pretty()
1677                );
1678            }
1679        }
1680
1681        Ok(())
1682    }
1683}
1684
1685/// Prints a type prettily with a given `ExprHumanizer`
1686pub fn columns_pretty<H>(cols: &[ReprColumnType], humanizer: &H) -> String
1687where
1688    H: ExprHumanizer,
1689{
1690    let mut s = String::with_capacity(2 + 3 * cols.len());
1691
1692    s.push('(');
1693
1694    let mut it = cols.iter().peekable();
1695    while let Some(col) = it.next() {
1696        s.push_str(&humanizer.humanize_column_type(col, false));
1697
1698        if it.peek().is_some() {
1699            s.push_str(", ");
1700        }
1701    }
1702
1703    s.push(')');
1704
1705    s
1706}
1707
1708impl ReprRelationTypeDifference {
1709    /// Pretty prints a type difference
1710    ///
1711    /// Always indents two spaces
1712    pub fn humanize<H>(&self, h: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1713    where
1714        H: ExprHumanizer,
1715    {
1716        use ReprRelationTypeDifference::*;
1717        match self {
1718            Length { len_sub, len_sup } => {
1719                writeln!(
1720                    f,
1721                    "  number of columns do not match ({len_sub} != {len_sup})"
1722                )
1723            }
1724            Column { col, diff } => {
1725                writeln!(f, "  column {col} differs:")?;
1726                diff.humanize(4, h, f)
1727            }
1728        }
1729    }
1730}
1731
1732impl ReprColumnTypeDifference {
1733    /// Pretty prints a type difference at a given indentation level
1734    pub fn humanize<H>(
1735        &self,
1736        indent: usize,
1737        h: &H,
1738        f: &mut std::fmt::Formatter<'_>,
1739    ) -> std::fmt::Result
1740    where
1741        H: ExprHumanizer,
1742    {
1743        use ReprColumnTypeDifference::*;
1744
1745        // indent
1746        write!(f, "{:indent$}", "")?;
1747
1748        match self {
1749            NotSubtype { sub, sup } => {
1750                let sub = h.humanize_scalar_type(sub, false);
1751                let sup = h.humanize_scalar_type(sup, false);
1752
1753                writeln!(f, "{sub} is a not a subtype of {sup}")
1754            }
1755            Nullability { sub, sup } => {
1756                let sub = h.humanize_column_type(sub, false);
1757                let sup = h.humanize_column_type(sup, false);
1758
1759                writeln!(f, "{sub} is nullable but {sup} is not")
1760            }
1761            ElementType { ctor, element_type } => {
1762                writeln!(f, "{ctor} element types differ:")?;
1763
1764                element_type.humanize(indent + 2, h, f)
1765            }
1766            RecordMissingFields { missing } => {
1767                write!(f, "missing column fields:")?;
1768                for col in missing {
1769                    write!(f, " {col}")?;
1770                }
1771                f.write_char('\n')
1772            }
1773            RecordFields { fields } => {
1774                writeln!(f, "{} record fields differ:", fields.len())?;
1775
1776                for (i, diff) in fields.iter().enumerate() {
1777                    writeln!(f, "{:indent$}  field {i}:", "")?;
1778                    diff.humanize(indent + 4, h, f)?;
1779                }
1780                Ok(())
1781            }
1782        }
1783    }
1784}
1785
1786impl DatumTypeDifference {
1787    /// Pretty prints a type difference at a given indentation level
1788    pub fn humanize<H>(
1789        &self,
1790        indent: usize,
1791        h: &H,
1792        f: &mut std::fmt::Formatter<'_>,
1793    ) -> std::fmt::Result
1794    where
1795        H: ExprHumanizer,
1796    {
1797        // indent
1798        write!(f, "{:indent$}", "")?;
1799
1800        match self {
1801            DatumTypeDifference::Null { expected } => {
1802                let expected = h.humanize_scalar_type(expected, false);
1803                writeln!(
1804                    f,
1805                    "unexpected null, expected representation type {expected}"
1806                )?
1807            }
1808            DatumTypeDifference::Mismatch {
1809                got_debug,
1810                expected,
1811            } => {
1812                let expected = h.humanize_scalar_type(expected, false);
1813                // NB `got_debug` will be redacted as appropriate
1814                writeln!(
1815                    f,
1816                    "got datum {got_debug}, expected representation type {expected}"
1817                )?;
1818            }
1819            DatumTypeDifference::MismatchDimensions {
1820                ctor,
1821                got,
1822                expected,
1823            } => {
1824                writeln!(
1825                    f,
1826                    "{ctor} dimensions differ: got datum with dimension {got}, expected dimension {expected}"
1827                )?;
1828            }
1829            DatumTypeDifference::ElementType { ctor, element_type } => {
1830                writeln!(f, "{ctor} element types differ:")?;
1831                element_type.humanize(indent + 4, h, f)?;
1832            }
1833        }
1834
1835        Ok(())
1836    }
1837}
1838
1839/// Wrapper struct for a `Display` instance for `TypeError`s with a given `ExprHumanizer`
1840#[allow(missing_debug_implementations)]
1841pub struct TypeErrorHumanizer<'a, 'b, H>
1842where
1843    H: ExprHumanizer,
1844{
1845    err: &'a TypeError<'a>,
1846    humanizer: &'b H,
1847}
1848
1849impl<'a, 'b, H> TypeErrorHumanizer<'a, 'b, H>
1850where
1851    H: ExprHumanizer,
1852{
1853    /// Create a `Display`-shim struct for a given `TypeError`/`ExprHumanizer` pair
1854    pub fn new(err: &'a TypeError, humanizer: &'b H) -> Self {
1855        Self { err, humanizer }
1856    }
1857}
1858
1859impl<'a, 'b, H> std::fmt::Display for TypeErrorHumanizer<'a, 'b, H>
1860where
1861    H: ExprHumanizer,
1862{
1863    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1864        self.err.humanize(self.humanizer, f)
1865    }
1866}
1867
1868impl<'a> std::fmt::Display for TypeError<'a> {
1869    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1870        TypeErrorHumanizer {
1871            err: self,
1872            humanizer: &DummyHumanizer,
1873        }
1874        .fmt(f)
1875    }
1876}
1877
1878impl<'a> TypeError<'a> {
1879    /// The source of the type error
1880    pub fn source(&self) -> Option<&'a MirRelationExpr> {
1881        use TypeError::*;
1882        match self {
1883            Unbound { source, .. }
1884            | NoSuchColumn { source, .. }
1885            | MismatchColumn { source, .. }
1886            | MismatchColumns { source, .. }
1887            | BadConstantRowLen { source, .. }
1888            | BadConstantRow { source, .. }
1889            | BadProject { source, .. }
1890            | BadJoinEquivalence { source, .. }
1891            | BadTopKGroupKey { source, .. }
1892            | BadTopKOrdering { source, .. }
1893            | BadLetRecBindings { source }
1894            | Shadowing { source, .. }
1895            | DisallowedDummy { source, .. } => Some(source),
1896            Recursion { .. } => None,
1897        }
1898    }
1899
1900    fn humanize<H>(&self, humanizer: &H, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
1901    where
1902        H: ExprHumanizer,
1903    {
1904        if let Some(source) = self.source() {
1905            writeln!(f, "In the MIR term:\n{}\n", source.pretty())?;
1906        }
1907
1908        use TypeError::*;
1909        match self {
1910            Unbound { source: _, id, typ } => {
1911                let typ = columns_pretty(&typ.column_types, humanizer);
1912                writeln!(f, "{id} is unbound\ndeclared type {typ}")?
1913            }
1914            NoSuchColumn {
1915                source: _,
1916                expr,
1917                col,
1918            } => writeln!(f, "{expr} references non-existent column {col}")?,
1919            MismatchColumn {
1920                source: _,
1921                got,
1922                expected,
1923                diffs,
1924                message,
1925            } => {
1926                let got = humanizer.humanize_column_type(got, false);
1927                let expected = humanizer.humanize_column_type(expected, false);
1928                writeln!(
1929                    f,
1930                    "mismatched column types: {message}\n      got {got}\nexpected {expected}"
1931                )?;
1932
1933                for diff in diffs {
1934                    diff.humanize(2, humanizer, f)?;
1935                }
1936            }
1937            MismatchColumns {
1938                source: _,
1939                got,
1940                expected,
1941                diffs,
1942                message,
1943            } => {
1944                let got = columns_pretty(got, humanizer);
1945                let expected = columns_pretty(expected, humanizer);
1946
1947                writeln!(
1948                    f,
1949                    "mismatched relation types: {message}\n      got {got}\nexpected {expected}"
1950                )?;
1951
1952                for diff in diffs {
1953                    diff.humanize(humanizer, f)?;
1954                }
1955            }
1956            BadConstantRowLen {
1957                source: _,
1958                got,
1959                expected,
1960            } => {
1961                let expected = columns_pretty(expected, humanizer);
1962                writeln!(
1963                    f,
1964                    "bad constant row\n      row has length {got}\nexpected row of type {expected}"
1965                )?
1966            }
1967            BadConstantRow {
1968                source: _,
1969                mismatches,
1970                expected,
1971            } => {
1972                let expected = columns_pretty(expected, humanizer);
1973
1974                let num_mismatches = mismatches.len();
1975                let plural = if num_mismatches == 1 { "" } else { "es" };
1976                writeln!(
1977                    f,
1978                    "bad constant row\n      got {num_mismatches} mismatch{plural}\nexpected row of type {expected}"
1979                )?;
1980
1981                if num_mismatches > 0 {
1982                    writeln!(f, "")?;
1983                    for (col, diff) in mismatches.iter() {
1984                        writeln!(f, "      column #{col}:")?;
1985                        diff.humanize(8, humanizer, f)?;
1986                    }
1987                }
1988            }
1989            BadProject {
1990                source: _,
1991                got,
1992                input_type,
1993            } => {
1994                let input_type = columns_pretty(input_type, humanizer);
1995
1996                writeln!(
1997                    f,
1998                    "projection of non-existant columns {got:?} from type {input_type}"
1999                )?
2000            }
2001            BadJoinEquivalence {
2002                source: _,
2003                got,
2004                message,
2005            } => {
2006                let got = columns_pretty(got, humanizer);
2007
2008                writeln!(f, "bad join equivalence {got}: {message}")?
2009            }
2010            BadTopKGroupKey {
2011                source: _,
2012                k,
2013                input_type,
2014            } => {
2015                let input_type = columns_pretty(input_type, humanizer);
2016
2017                writeln!(
2018                    f,
2019                    "TopK group key component references invalid column {k} in columns: {input_type}"
2020                )?
2021            }
2022            BadTopKOrdering {
2023                source: _,
2024                order,
2025                input_type,
2026            } => {
2027                let col = order.column;
2028                let num_cols = input_type.len();
2029                let are = if num_cols == 1 { "is" } else { "are" };
2030                let s = if num_cols == 1 { "" } else { "s" };
2031                let input_type = columns_pretty(input_type, humanizer);
2032
2033                // TODO(cloud#8196)
2034                let mode = HumanizedExplain::new(false);
2035                let order = mode.expr(order, None);
2036
2037                writeln!(
2038                    f,
2039                    "TopK ordering {order} references invalid column {col}\nthere {are} {num_cols} column{s}: {input_type}"
2040                )?
2041            }
2042            BadLetRecBindings { source: _ } => {
2043                writeln!(f, "LetRec ids and definitions don't line up")?
2044            }
2045            Shadowing { source: _, id } => writeln!(f, "id {id} is shadowed")?,
2046            DisallowedDummy { source: _ } => writeln!(f, "contains a dummy value")?,
2047            Recursion { error } => writeln!(f, "{error}")?,
2048        }
2049
2050        Ok(())
2051    }
2052}
2053
#[cfg(test)]
mod tests {
    use mz_ore::{assert_err, assert_ok};
    use mz_repr::{SqlColumnType, arb_datum, arb_datum_for_column};
    use proptest::prelude::*;

    use super::*;

    /// Asserts that `datum_difference_with_column_type` succeeds exactly when
    /// `Datum::is_instance_of` accepts `datum` for `typ`.
    fn assert_agrees_with_is_instance_of(datum: Datum<'_>, typ: &ReprColumnType) {
        let diff = datum_difference_with_column_type(&datum, typ);
        if datum.is_instance_of(typ) {
            assert_ok!(diff);
        } else {
            assert_err!(diff);
        }
    }

    /// An `Int16` datum is accepted for a nullable `Int16` column and rejected for a
    /// non-nullable `Int32` column.
    #[mz_ore::test]
    fn test_datum_type_difference() {
        let datum = Datum::Int16(1);

        let matching = ReprColumnType {
            scalar_type: ReprScalarType::Int16,
            nullable: true,
        };
        assert_ok!(datum_difference_with_column_type(&datum, &matching));

        let mismatched = ReprColumnType {
            scalar_type: ReprScalarType::Int32,
            nullable: false,
        };
        assert_err!(datum_difference_with_column_type(&datum, &mismatched));
    }

    proptest! {
        // Data generated for a column type may contain dummies; those cases are rejected
        // (behavior is different!), so we allow many global rejects.
        #![proptest_config(ProptestConfig {
            cases: 5000,
            max_global_rejects: 2500,
            ..Default::default()
        })]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_with_instance_of_on_valid_data(
            (src, datum) in any::<SqlColumnType>()
                .prop_flat_map(|src| {
                    let datum = arb_datum_for_column(src.clone());
                    (Just(src), datum)
                })
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            if datum.contains_dummy() {
                return Err(TestCaseError::reject("datum contains a dummy"));
            }

            assert_agrees_with_is_instance_of(datum, &typ);
        }
    }

    proptest! {
        // We run many cases because the data are _random_, and we want to be sure
        // that we have covered sufficient cases. The generator is expected not to
        // produce dummy data here, which the test asserts rather than rejects.
        #![proptest_config(ProptestConfig::with_cases(10000))]
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn datum_type_difference_agrees_with_is_instance_of_on_random_data(
            src in any::<SqlColumnType>(),
            datum in arb_datum(false),
        ) {
            let typ = ReprColumnType::from(&src);
            let datum = Datum::from(&datum);

            assert!(!datum.contains_dummy(), "datum contains a dummy (bug in arb_datum)");

            assert_agrees_with_is_instance_of(datum, &typ);
        }
    }

    /// A list containing a null is not an instance of a record with one non-nullable
    /// `UInt32` field, and the difference computation must agree.
    #[mz_ore::test]
    fn datum_type_difference_github_10039() {
        let field = ReprColumnType {
            scalar_type: ReprScalarType::UInt32,
            nullable: false,
        };
        let typ = ReprColumnType {
            scalar_type: ReprScalarType::Record {
                fields: Box::new([field]),
            },
            nullable: false,
        };

        let mut row = mz_repr::Row::default();
        row.packer()
            .push_list(std::iter::once(mz_repr::Datum::Null));
        let datum = row.unpack_first();

        assert!(!datum.is_instance_of(&typ));
        let diff = datum_difference_with_column_type(&datum, &typ);
        assert_err!(diff);
    }
}